In [9]:
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset, ConcatDataset
from torchvision import transforms, datasets
from PIL import Image
import os
import zipfile
import torch.nn as nn
import numpy as np
from sklearn.model_selection import KFold
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [10]:
# Additional data augmentation techniques
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [11]:
# Load experimental data
def load_all_experimental_data(test_digits_folder):
    train_images = []
    train_labels = []
    test_images = []
    test_labels = []
    participant_data = {}

    transform = transforms.Compose([
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    for filename in os.listdir(test_digits_folder):
        if filename.endswith('.zip') and filename.startswith('experiment_results_participant'):
            participant_number = int(filename.split('participant')[1].split('.')[0])
            zip_filepath = os.path.join(test_digits_folder, filename)

            participant_train_images = []
            participant_train_labels = []
            participant_test_images = []
            participant_test_labels = []

            with zipfile.ZipFile(zip_filepath, 'r') as zip_ref:
                for img_filename in zip_ref.namelist():
                    if img_filename.endswith('.png'):
                        with zip_ref.open(img_filename) as file:
                            img = Image.open(file).convert('L')
                            img_tensor = transform(img)
                            
                            digit = int(img_filename.split('_')[0])
                            
                            if 'composite' in img_filename:
                                test_images.append(img_tensor)
                                test_labels.append(digit)
                                participant_test_images.append(img_tensor)
                                participant_test_labels.append(digit)
                            else:
                                train_images.append(img_tensor)
                                train_labels.append(digit)
                                participant_train_images.append(img_tensor)
                                participant_train_labels.append(digit)

            participant_data[participant_number] = {
                'train': (torch.stack(participant_train_images), torch.tensor(participant_train_labels)),
                'test': (torch.stack(participant_test_images), torch.tensor(participant_test_labels))
            }

    return (torch.stack(train_images), torch.tensor(train_labels), 
            torch.stack(test_images), torch.tensor(test_labels),
            participant_data)

In [12]:
def load_augmented_mnist():
    transform = transforms.Compose([
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
        transforms.RandomApply([
            transforms.RandomRotation(10),
            transforms.RandomAffine(0, translate=(0.1, 0.1)),
            AddGaussianNoise(0., 0.05),
        ], p=0.5),
    ])
    
    invert_transform = transforms.Compose([
        transform,
        transforms.Lambda(lambda x: 1 - x),
    ])
    
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    
    inverted_mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=invert_transform)
    inverted_mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=invert_transform)
    
    combined_train = ConcatDataset([mnist_train, inverted_mnist_train])
    combined_test = ConcatDataset([mnist_test, inverted_mnist_test])
    
    return combined_train, combined_test

In [13]:
class MixedDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        if isinstance(self.dataset, TensorDataset):
            x, y = self.dataset[index]
        else:
            x, y = self.dataset[index]
        
        # Ensure x is a tensor
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        
        # Ensure y is a tensor
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y)
        
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.dataset)

In [14]:
class LeNet5_16x16(pl.LightningModule):
    def __init__(self, num_classes=10, learning_rate=0.001):
        super(LeNet5_16x16, self).__init__()
        self.learning_rate = learning_rate
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2)
        )
        self.fc1 = nn.Linear(16 * 1 * 1, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)
        return {'val_loss': loss, 'val_acc': acc}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [15]:
class EnsembleDataModule(pl.LightningDataModule):
    def __init__(self, exp_data_path, batch_size=64):
        super().__init__()
        self.exp_data_path = exp_data_path
        self.batch_size = batch_size

    def setup(self, stage=None):
        exp_train_images, exp_train_labels, self.exp_test_images, self.exp_test_labels, _ = load_all_experimental_data(self.exp_data_path)
        exp_dataset = TensorDataset(exp_train_images, exp_train_labels)
        mnist_train, mnist_test = load_augmented_mnist()
        self.combined_dataset = ConcatDataset([MixedDataset(exp_dataset), MixedDataset(mnist_train)])
        self.mnist_test = MixedDataset(mnist_test)

    def train_dataloader(self):
        return DataLoader(self.combined_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    def test_dataloader(self):
        exp_test_dataset = TensorDataset(self.exp_test_images, self.exp_test_labels)
        return DataLoader(exp_test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True)

def train_ensemble(num_models=25, num_epochs=30, patience=5):
    ensemble = []
    data_module = EnsembleDataModule('test_digits')

    for i in range(num_models):
        print(f"Training model {i+1}/{num_models}")

        model = LeNet5_16x16()
        early_stop_callback = EarlyStopping(monitor='val_loss', patience=patience)
        trainer = pl.Trainer(max_epochs=num_epochs, callbacks=[early_stop_callback], gpus=1 if torch.cuda.is_available() else 0)
        
        trainer.fit(model, data_module)
        
        ensemble.append(model)
        
        # Save the model
        torch.save(model.state_dict(), f"lenet5_trained_model_ensemble_{i+1}.pth")
        print(f"Model {i+1} saved as lenet5_trained_model_ensemble_{i+1}.pth")

    return ensemble

In [16]:
# Training Function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, patience):
    model = model.to(device)
    best_val_acc = 0
    epochs_no_improve = 0
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss:.4f}')
        
        # Validation
        model.eval()
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for val_images, val_labels in val_loader:
                val_images, val_labels = val_images.to(device), val_labels.to(device)
                outputs_val = model(val_images)
                _, predicted_val = torch.max(outputs_val.data, 1)
                total_val += val_labels.size(0)
                correct_val += (predicted_val == val_labels).sum().item()
        
        val_acc = 100 * correct_val / total_val
        print(f'Validation Accuracy: {val_acc:.2f}%')
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break
    
    return model

In [17]:
# K-Fold Cross Validation
def k_fold_cross_validation(dataset, num_folds=5, num_epochs=50, patience=5):
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)
    fold_results = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for fold, (train_index, val_index) in enumerate(kf.split(range(len(dataset)))):
        print(f"Fold {fold + 1}/{num_folds}")
        train_subset = torch.utils.data.Subset(dataset, train_index)
        val_subset = torch.utils.data.Subset(dataset, val_index)

        train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)

        model = LeNet5_16x16(num_classes=10).to(device)
        
        # Calculate class weights for weighted cross-entropy
        labels = torch.tensor([dataset[i][1] for i in train_index])
        class_counts = torch.bincount(labels)
        class_weights = 1. / class_counts.float()
        class_weights = class_weights / class_weights.sum()
        criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
        
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, patience)

        # Evaluate on validation set
        model.eval()
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for val_images, val_labels in val_loader:
                val_images, val_labels = val_images.to(device), val_labels.to(device)
                outputs_val = model(val_images)
                _, predicted_val = torch.max(outputs_val.data, 1)
                total_val += val_labels.size(0)
                correct_val += (predicted_val == val_labels).sum().item()
        
        val_accuracy = 100 * correct_val / total_val
        fold_results.append(val_accuracy)
        print(f'Fold {fold + 1} Validation Accuracy: {val_accuracy:.2f}%')

    return fold_results

In [18]:
def train_ensemble(num_models=25, num_folds=3, num_epochs=30, patience=3):
    ensemble = []
    
    for i in range(num_models):
        print(f"Training model {i+1}/{num_models}")
        
        # Load experimental data
        exp_train_images, exp_train_labels, exp_test_images, exp_test_labels, participant_data = load_all_experimental_data('test_digits')
        exp_dataset = TensorDataset(exp_train_images, exp_train_labels)

        # Load augmented MNIST data
        mnist_train, mnist_test = load_augmented_mnist()

        # Combine experimental data with augmented MNIST data
        combined_dataset = ConcatDataset([MixedDataset(exp_dataset), MixedDataset(mnist_train)])

        # Perform K-Fold Cross Validation
        fold_results = k_fold_cross_validation(combined_dataset, num_folds=num_folds, num_epochs=num_epochs, patience=patience)
        print(f"Model {i+1} - Average Validation Accuracy across folds: {np.mean(fold_results):.2f}%")

        # Train final model on all data
        final_train_loader = DataLoader(combined_dataset, batch_size=64, num_workers=4, pin_memory=True,shuffle=True)
        final_val_loader = DataLoader(MixedDataset(mnist_test), batch_size=64, num_workers=4, pin_memory=True,shuffle=False)

        model = LeNet5_16x16(num_classes=10).to(device)
        
        # Calculate class weights for the final model
        all_labels = []
        for dataset in combined_dataset.datasets:
            if isinstance(dataset, TensorDataset):
                all_labels.extend(dataset.tensors[1].tolist())
            else:
                all_labels.extend([label for _, label in dataset])
        
        labels = torch.tensor(all_labels)
        class_counts = torch.bincount(labels)
        class_weights = 1. / class_counts.float()
        class_weights = class_weights / class_weights.sum()
        
        criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        model = train_model(model, final_train_loader, final_val_loader, criterion, optimizer, num_epochs=num_epochs, patience=patience)
        
        ensemble.append(model)
        
        # Save the model
        torch.save(model.state_dict(), f"lenet5_trained_model_ensemble_{i+1}.pth")
        print(f"Model {i+1} saved as lenet5_trained_model_ensemble_{i+1}.pth")
    
    return ensemble

def ensemble_predict(ensemble, input_tensor):
    predictions = []
    for model in ensemble:
        model.eval()
        with torch.no_grad():
            output = model(input_tensor)
            predictions.append(torch.softmax(output, dim=1))
    avg_prediction = torch.mean(torch.stack(predictions), dim=0)
    return torch.argmax(avg_prediction, dim=1)

if __name__ == "__main__":
    ensemble = train_ensemble(num_models=25, num_epochs=30, patience=5)

    # Test the ensemble on experimental data
    data_module = EnsembleDataModule('test_digits')
    data_module.setup()
    test_loader = data_module.test_dataloader()

    correct = 0
    total = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predictions = ensemble_predict(ensemble, images)
        total += labels.size(0)
        correct += (predictions == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Ensemble accuracy on experimental data: {accuracy:.2f}%")

Training model 1/25
Fold 1/3
Epoch [1/30], Loss: 5878.8590
Validation Accuracy: 58.77%
Epoch [2/30], Loss: 3571.6743
Validation Accuracy: 68.27%
Epoch [3/30], Loss: 3150.5706
Validation Accuracy: 69.76%
Epoch [4/30], Loss: 3000.6706
Validation Accuracy: 71.15%
Epoch [5/30], Loss: 2892.7960
Validation Accuracy: 71.96%
Epoch [6/30], Loss: 2820.2248
Validation Accuracy: 72.67%
Epoch [7/30], Loss: 2764.6476
Validation Accuracy: 72.27%
Epoch [8/30], Loss: 2719.3309
Validation Accuracy: 72.80%
Epoch [9/30], Loss: 2671.9581
Validation Accuracy: 72.93%
Epoch [10/30], Loss: 2638.2961
Validation Accuracy: 73.25%
Epoch [11/30], Loss: 2615.6605
Validation Accuracy: 73.64%
Epoch [12/30], Loss: 2585.4855
Validation Accuracy: 73.33%
Epoch [13/30], Loss: 2556.4389
Validation Accuracy: 73.23%
Epoch [14/30], Loss: 2550.7736
Validation Accuracy: 74.05%
Epoch [15/30], Loss: 2518.8984
Validation Accuracy: 73.77%
Epoch [16/30], Loss: 2501.6597
Validation Accuracy: 74.07%
Epoch [17/30], Loss: 2494.4732
Valid

AttributeError: Can't pickle local object 'load_augmented_mnist.<locals>.<lambda>'