In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms


In [None]:
from torch.utils.data import Dataset, DataLoader

# Names of the ten CIFAR-10 classes, listed so that a class's position in the
# tuple equals its label id (e.g. index 1 is 'car', index 6 is 'frog').
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ---------------------- Filter to our wanted classes only ---------------------
# Original label ids to keep: 1 = 'car', 6 = 'frog', 7 = 'horse', 8 = 'ship'.
# filter_dataset_by_class() renumbers them by their position in this list,
# so after filtering the labels are 0, 1, 2, 3 in that order.
classes_to_keep = [1, 6, 7, 8]
# Load the CIFAR-10 dataset (train and test splits) through the Keras helper
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Function to filter the dataset by the specified classes
def filter_dataset_by_class(data, labels, classes):
    """Keep only the samples whose label is in ``classes`` and renumber the labels.

    Args:
        data: Array of samples indexed along axis 0, one entry per label.
        labels: Integer labels of any shape; flattened before use.
        classes: Iterable of original label ids to keep. The position of an id
            in this iterable becomes its new label (first occurrence wins if
            an id is repeated).

    Returns:
        ``(filtered_data, filtered_labels)``. ``filtered_labels`` is a 1-D
        integer array with values in ``range(len(classes))``; it is an empty
        integer array when nothing matches.
    """
    labels = np.asarray(labels).flatten()
    # Materialise as an array so lists, tuples and NumPy arrays all work.
    class_ids = np.asarray(list(classes))
    mask = np.isin(labels, class_ids)
    filtered_data = data[mask]
    kept_labels = labels[mask]

    if class_ids.size == 0:
        # Nothing can match; argmax over an empty axis would raise.
        return filtered_data, np.empty(0, dtype=np.intp)

    # Row i of `matches` is a one-hot marker of which kept class sample i has;
    # argmax returns the first matching column, i.e. the new label.
    matches = kept_labels[:, None] == class_ids[None, :]
    filtered_labels = matches.argmax(axis=1)
    return filtered_data, filtered_labels

# Filter the training and testing datasets down to the classes in
# classes_to_keep; the returned labels are already renumbered from 0.
x_train_filtered, y_train_filtered = filter_dataset_by_class(x_train, y_train, classes_to_keep)
x_test_filtered, y_test_filtered = filter_dataset_by_class(x_test, y_test, classes_to_keep)

# Normalize the images to [-1, 1]: scale by 1/255, subtract 0.5, divide by 0.5.
# (Assumes the raw pixel values lie in [0, 255] - TODO confirm against the loader.)
# imshow() later undoes this with `img / 2 + 0.5`.
x_train_filtered = (x_train_filtered.astype('float32') / 255.0 - 0.5) / 0.5
x_test_filtered = (x_test_filtered.astype('float32') / 255.0 - 0.5) / 0.5
# Center crop images to crop_size x crop_size
def center_crop(images, crop_size):
    """Cut a centred ``crop_size`` x ``crop_size`` window out of every image.

    Args:
        images: Array indexed as (batch, height, width, channels).
        crop_size: Side length of the square window to keep.

    Returns:
        The cropped view ``images[:, top:top+crop_size, left:left+crop_size, :]``.

    Raises:
        ValueError: If ``crop_size`` is negative or larger than the image's
            height or width.
    """
    height, width = images.shape[1], images.shape[2]
    if crop_size < 0 or crop_size > min(height, width):
        raise ValueError(
            f'crop_size must be between 0 and {min(height, width)}, got {crop_size}'
        )
    # Offsets are computed per axis so non-square images are centred correctly.
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return images[:, top:top + crop_size, left:left + crop_size, :]

# x_train_filtered = center_crop(x_train_filtered, 28)
# x_test_filtered = center_crop(x_test_filtered, 28)

# Create a custom Dataset class for PyTorch
class KerasDataset(Dataset):
    """Map-style dataset over in-memory image and label arrays.

    Indexing returns ``(image, label)`` where the image has been converted to
    a tensor with its last axis moved to the front (HWC -> CHW) and the label
    is a 64-bit integer tensor.
    """

    def __init__(self, images, labels):
        """Store the image and label containers; no copy is made here."""
        self.labels = labels
        self.images = images

    def __len__(self):
        """Number of samples, taken from the image container."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a ``(CHW image tensor, long label tensor)`` pair."""
        # torch.tensor copies the array data; permute reorders HWC -> CHW.
        chw_image = torch.tensor(self.images[idx]).permute(2, 0, 1)
        target = torch.tensor(self.labels[idx]).long()
        return chw_image, target

# Wrap the filtered arrays in PyTorch datasets and build the DataLoaders
batch_size = 32
trainset = KerasDataset(x_train_filtered, y_train_filtered)
testset = KerasDataset(x_test_filtered, y_test_filtered)

# Create trainloader (train dataset); samples are reshuffled on every pass.
# Both loaders take their size from `batch_size` so later cells that use that
# variable stay in sync with the batches actually produced.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

# Create testloader (test dataset); order is kept fixed
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# Example: peek at the first four batches of the filtered trainloader
for i, (images, labels) in enumerate(trainloader):
    print(f'Batch size: {images.size(0)}, Labels: {labels}')
    if i == 3:
        break

# NOTE(review): KerasDataset defines no __repr__, so these presumably print the
# default object representation - consider printing len(trainset) instead.
print(trainset)
print(testset)


In [None]:
# Display names of the four retained classes, in the order of the renumbered
# labels 0-3 (CIFAR-10 ids 1, 6, 7, 8 from classes_to_keep). This rebinds the
# ten-class `classes` tuple defined in the data-loading cell.
classes = ('car', 'frog', 'horse', 'ship')
# functions to show an image
def imshow(img):
    """Display a normalised channel-first image tensor with matplotlib.

    The tensor is mapped back with ``img / 2 + 0.5`` (the inverse of the
    ``(x - 0.5) / 0.5`` normalisation), converted to NumPy and transposed from
    CHW to HWC before being handed to ``plt.imshow``.
    """
    unnormalized = img / 2 + 0.5
    hwc_array = np.transpose(unnormalized.numpy(), (1, 2, 0))
    plt.imshow(hwc_array)
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels; iterate over the labels actually returned, because a batch
# can hold fewer than `batch_size` samples (e.g. the last one of an epoch)
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(len(labels))))

In [None]:
class Net(nn.Module):
    """LeNet-style CNN shared by every ``Net_<param count>`` factory below.

    Layout: conv(3->6, 5x5) -> ReLU -> maxpool(2) -> conv(6->conv2_channels, 5x5)
    -> ReLU -> maxpool(2) -> flatten -> fc1 -> ReLU -> fc2 -> ReLU -> fc3.

    ``fc1`` expects ``conv2_channels * 5 * 5`` features, which is what the two
    conv/pool stages leave from a 32x32 input (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        conv2_channels: Output channels of the second convolution.
        fc1_units: Width of the first fully connected layer.
        fc2_units: Width of the second fully connected layer.
        num_classes: Number of output logits (4 for every factory here).
    """

    def __init__(self, conv2_channels, fc1_units, fc2_units, num_classes=4):
        super().__init__()
        # Attribute names and creation order are kept so state_dict keys and
        # random initialisation match the original per-factory classes.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, conv2_channels, 5)
        self.fc1 = nn.Linear(conv2_channels * 5 * 5, fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def Net_36480():
    """conv2: 10 channels, fc: 110 -> 60 -> 4 (36,480 parameters)."""
    return Net(10, 110, 60)


def Net_10608():
    """conv2: 6 channels, fc: 40 -> 20 -> 4.

    NOTE(review): these layer sizes give 8,306 trainable parameters, not the
    10,608 suggested by the name - confirm which one is intended.
    """
    return Net(6, 40, 20)


def Net_13093():
    """conv2: 8 channels, fc: 50 -> 25 -> 4 (13,093 parameters)."""
    return Net(8, 50, 25)


def Net_15895():
    """conv2: 10 channels, fc: 50 -> 25 -> 4 (15,895 parameters)."""
    return Net(10, 50, 25)


def Net_39590():
    """conv2: 10 channels, fc: 120 -> 60 -> 4 (39,590 parameters)."""
    return Net(10, 120, 60)


def Net_23299():
    """conv2: 9 channels, fc: 80 -> 40 -> 4 (23,299 parameters)."""
    return Net(9, 80, 40)


def Net_27150():
    """conv2: 10 channels, fc: 80 -> 60 -> 4 (27,150 parameters)."""
    return Net(10, 80, 60)


def Net_45892():
    """conv2: 12 channels, fc: 120 -> 60 -> 4 (45,892 parameters)."""
    return Net(12, 120, 60)


def Net_42052():
    """conv2: 12 channels, fc: 110 -> 58 -> 4 (42,052 parameters)."""
    return Net(12, 110, 58)


def Net_30397():
    """conv2: 10 channels, fc: 92 -> 55 -> 4 (30,397 parameters)."""
    return Net(10, 92, 55)


def Net_33370():
    """conv2: 10 channels, fc: 100 -> 60 -> 4 (33,370 parameters)."""
    return Net(10, 100, 60)


def Net_16819():
    """conv2: 9 channels, fc: 80 -> 40 -> 4.

    NOTE(review): these are exactly the layer sizes of Net_23299 (23,299
    trainable parameters), not 16,819 - confirm the intended architecture
    before loading a 16819 checkpoint into this network.
    """
    return Net(9, 80, 40)


def Net_3305():
    """conv2: 5 channels, fc: 15 -> 10 -> 4 (3,305 parameters)."""
    return Net(5, 15, 10)


def Net_5695():
    """conv2: 5 channels, fc: 30 -> 20 -> 4 (5,695 parameters)."""
    return Net(5, 30, 20)


def Net_62716():
    """conv2: 16 channels, fc: 128 -> 64 -> 4 (62,716 parameters)."""
    return Net(16, 128, 64)


def Net_54694():
    """conv2: 14 channels, fc: 120 -> 80 -> 4 (54,694 parameters)."""
    return Net(14, 120, 80)


def Net_75426():
    """conv2: 16 channels, fc: 150 -> 80 -> 4 (75,426 parameters)."""
    return Net(16, 150, 80)

# Mapping from parameter count to the corresponding network factory.
# NOTE: the next cell rebinds `models_to_train` to a dict of
# {'net': ..., 'epochs': ...} entries, so this mapping is only in effect
# until that cell runs.
models_to_train = {
    10608: Net_10608,
    36480: Net_36480,
    15895: Net_15895,
    16819: Net_16819,
    45892: Net_45892,
    30397: Net_30397,
    27150: Net_27150,
    3305: Net_3305,
    23299: Net_23299,
    5695: Net_5695,
    42052: Net_42052,
    75426: Net_75426,  # key was mistyped as 72426; it must match the factory name
    54694: Net_54694,
}


In [None]:
# Networks to train from scratch on noisy labels. Every entry shares the same
# epoch budget, so it is stated once instead of being repeated per model.
NOISY_EPOCHS = 50
models_to_train = {
    param_count: {'net': factory, 'epochs': NOISY_EPOCHS}
    for param_count, factory in [
        (10608, Net_10608),
        (36480, Net_36480),
        (15895, Net_15895),
        (16819, Net_16819),
        (45892, Net_45892),
        (30397, Net_30397),
        (27150, Net_27150),
        (3305, Net_3305),
        (23299, Net_23299),
        (5695, Net_5695),
        (42052, Net_42052),
    ]
}

# Directory to save the models and results.
# `os` has to be imported before this cell runs; it is part of the notebook's
# first import cell so a fresh-kernel "Run All" does not hit a NameError here.
os.makedirs('result_noise', exist_ok=True)
# Function to add noise to the labels
def add_label_noise(labels, noise_ratio=0.15, num_classes=4, generator=None):
    """Return a copy of ``labels`` with a random subset replaced by random classes.

    ``int(noise_ratio * len(labels))`` distinct positions are selected and each
    is overwritten with a class drawn uniformly from ``range(num_classes)``.
    A draw may coincide with the original label, so the share of labels that
    actually change can be lower than ``noise_ratio``.

    All randomness comes from torch (not NumPy), so ``torch.manual_seed`` or
    an explicit ``generator`` is enough to make the corruption reproducible.

    Args:
        labels: 1-D integer tensor of class labels; it is not modified.
        noise_ratio: Fraction of positions to overwrite, in ``[0, 1]``.
        num_classes: Number of classes to draw replacement labels from.
        generator: Optional CPU ``torch.Generator`` used for both the index
            selection and the replacement draw.

    Returns:
        A new tensor with the same shape, dtype and device as ``labels``.

    Raises:
        ValueError: If ``noise_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= noise_ratio <= 1:
        raise ValueError(f'noise_ratio must be in [0, 1], got {noise_ratio}')

    noisy_labels = labels.clone()
    num_noisy = int(noise_ratio * len(labels))
    if num_noisy == 0:
        return noisy_labels

    # Randomly select distinct indices to corrupt
    noisy_indices = torch.randperm(len(labels), generator=generator)[:num_noisy]

    # Replace the selected labels with random classes; move the draws to the
    # labels' device/dtype so the assignment also works for non-CPU tensors.
    replacements = torch.randint(0, num_classes, (num_noisy,), generator=generator)
    noisy_labels[noisy_indices.to(labels.device)] = replacements.to(
        device=labels.device, dtype=labels.dtype
    )

    return noisy_labels

# Training function
def train_model(model, trainloader, testloader, criterion, optimizer, n_epochs, save_path):
    """Train ``model`` on noise-corrupted labels, then save weights and metrics.

    Each training batch has its labels passed through ``add_label_noise``
    before the loss is computed. After every epoch the model is evaluated on
    ``trainloader`` and ``testloader`` with their original (uncorrupted)
    labels, and the four metrics are recorded.

    Args:
        model: Network to train in place.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        testloader: Iterable of ``(inputs, labels)`` evaluation batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        n_epochs: Number of passes over ``trainloader``.
        save_path: File name ending in ``.pth``; outputs are written inside
            ``result_noise`` as ``*_noisy.pth`` and ``*_noisy_results.csv``.
    """
    train_accuracy = []
    test_accuracy = []
    train_losses = []
    test_losses = []

    for epoch in range(n_epochs):
        # Evaluation at the end of each epoch puts the model in eval mode, so
        # training mode is re-enabled per epoch rather than once up front.
        model.train()
        for inputs, labels in trainloader:
            # Add noise to the labels
            noisy_labels = add_label_noise(labels)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, noisy_labels)
            loss.backward()
            optimizer.step()

        # Calculate and store train and test accuracy/loss
        train_acc, train_loss = calculate_accuracy_and_loss(model, trainloader, criterion)
        test_acc, test_loss = calculate_accuracy_and_loss(model, testloader, criterion, test=True)

        train_accuracy.append(train_acc)
        test_accuracy.append(test_acc)
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        print(f'Epoch {epoch+1}, Train accuracy: {train_acc}, Test accuracy: {test_acc}, Train loss: {train_loss}, Test loss: {test_loss}')

    # Save the model after training
    new_model_path = os.path.join('result_noise', save_path.replace('.pth', '_noisy.pth'))
    torch.save(model.state_dict(), new_model_path)

    # Save the results as a CSV, one row per epoch
    new_results_path = os.path.join('result_noise', save_path.replace('.pth', '_noisy_results.csv'))
    results = pd.DataFrame({
        'train_accuracy': train_accuracy,
        'test_accuracy': test_accuracy,
        'train_loss': train_losses,
        'test_loss': test_losses
    })
    results.to_csv(new_results_path, index=False)
    print(f'Saved model to {new_model_path} and results to {new_results_path}')

def calculate_accuracy_and_loss(model, dataloader, criterion, test=False):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    The model is switched to eval mode for the duration of the call and is
    then put back into whichever mode (train/eval) the caller left it in, so
    evaluating in the middle of training does not silently leave the model in
    eval mode.

    Args:
        model: Network to evaluate.
        dataloader: Sized iterable of ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        test: Unused; kept so existing callers that pass it keep working.

    Returns:
        ``(accuracy, avg_loss)``: ``accuracy`` is the percentage of correct
        predictions floored to an integer; ``avg_loss`` is the sum of the
        per-batch losses divided by the number of batches.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                outputs = model(images)
                running_loss += criterion(outputs, labels).item()
                # Predicted class = index of the largest logit in each row.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)

    accuracy = 100 * correct // total
    avg_loss = running_loss / len(dataloader)
    return accuracy, avg_loss

# Loss shared by every training run in this cell
criterion = nn.CrossEntropyLoss()

# Build each configured network from scratch and train it with noisy labels
for key in models_to_train:
    model_info = models_to_train[key]
    print(f"Processing model with {key} parameters")

    # 'net' holds a zero-argument factory; calling it yields a fresh network
    model = model_info['net']()

    # A new optimizer per model, bound to that model's parameters
    optimizer = optim.SGD(model.parameters(), momentum=0.90, lr=0.001)

    # Weights and metrics are written under the name model_<key>.pth
    train_model(
        model,
        trainloader,
        testloader,
        criterion,
        optimizer,
        n_epochs=model_info['epochs'],
        save_path=f'model_{key}.pth',
    )


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import pandas as pd
# Lookup from (integer) parameter count to the factory that builds the
# matching architecture; used below to re-create a network before its saved
# weights are loaded into it.
net_dict = {
    10608: Net_10608,
    36480: Net_36480,
    15895: Net_15895,
    16819: Net_16819,
    45892: Net_45892,
    30397: Net_30397,
    27150: Net_27150,
    3305: Net_3305,
    23299: Net_23299,
    5695: Net_5695,
    42052: Net_42052,
    75426: Net_75426,
    54694: Net_54694,
    39590: Net_39590,
    33370: Net_33370,
    13093: Net_13093
}
# Define the paths and number of additional epochs for each model.
# Keys are strings here and are converted with int() before the net_dict lookup.
# This rebinds `models_to_train` from the earlier cells.
# NOTE(review): the paths are absolute and machine-specific; consider deriving
# them from a configurable results directory.
# NOTE(review): Net_16819 is defined with the same layer sizes as Net_23299 -
# confirm the 16819 checkpoint below matches that architecture.
models_to_train = {
    # '13093': {'path': '/cs_storage/itayab/PyCharmProjects/Rethinking_Generalization/src/results/13093/model_13093.pth', 'additional_epochs': 50},
    # '33370': {'path': '/content/drive/MyDrive/MSc courses/OML project/4_class_cnn/33370/model_33370.pth', 'additional_epochs': 30},
    # '36480': {'path': '/content/drive/MyDrive/MSc courses/OML project/4_class_cnn/36480/model_36480.pth', 'additional_epochs': 30},
    # '39590': {'path': '/content/drive/MyDrive/MSc courses/OML project/4_class_cnn/39590/model_39590.pth', 'additional_epochs': 30},
    # '3305': {'path': '/cs_storage/itayab/PyCharmProjects/Rethinking_Generalization/src/results/3305/model_3305.pth', 'additional_epochs': 50},
    # '15895' : {'path': '/cs_storage/itayab/PyCharmProjects/Rethinking_Generalization/src/results/15895/model_15895.pth', 'additional_epochs': 50},
    '16819' : {'path': '/cs_storage/itayab/PyCharmProjects/Rethinking_Generalization/src/results/16819/model_16819.pth', 'additional_epochs': 50}
}


# Training function
def train_model(model, trainloader, testloader, criterion, optimizer, n_epochs, save_path):
    """Continue training ``model`` and save the retrained weights and metrics.

    After every epoch the model is evaluated on ``trainloader`` and
    ``testloader`` and the four metrics are recorded. This definition replaces
    the noisy-label ``train_model`` from the earlier cell once this cell runs.

    Args:
        model: Network to train in place.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        testloader: Iterable of ``(inputs, labels)`` evaluation batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        n_epochs: Number of passes over ``trainloader``.
        save_path: Path ending in ``.pth``; outputs are written next to it as
            ``*_retrained.pth`` and ``*_retrained_results.csv``.
    """
    train_accuracy = []
    test_accuracy = []
    train_losses = []
    test_losses = []

    for epoch in range(n_epochs):
        # Evaluation at the end of each epoch puts the model in eval mode, so
        # training mode is re-enabled per epoch rather than once up front.
        model.train()
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Calculate and store train and test accuracy/loss
        train_acc, train_loss = calculate_accuracy_and_loss(model, trainloader, criterion)
        test_acc, test_loss = calculate_accuracy_and_loss(model, testloader, criterion, test=True)

        train_accuracy.append(train_acc)
        test_accuracy.append(test_acc)
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        print(f'Epoch {epoch+1}, Train accuracy: {train_acc}, Test accuracy: {test_acc}, Train loss: {train_loss}, Test loss: {test_loss}')

    # Save the model after training with a new name
    new_model_path = save_path.replace('.pth', '_retrained.pth')
    torch.save(model.state_dict(), new_model_path)

    # Save the results as a CSV with a new name, one row per epoch
    new_results_path = save_path.replace('.pth', '_retrained_results.csv')
    results = pd.DataFrame({
        'train_accuracy': train_accuracy,
        'test_accuracy': test_accuracy,
        'train_loss': train_losses,
        'test_loss': test_losses
    })
    results.to_csv(new_results_path, index=False)
    print(f'Saved model to {new_model_path} and results to {new_results_path}')

def calculate_accuracy_and_loss(model, dataloader, criterion, test=False):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    The model is switched to eval mode for the duration of the call and is
    then put back into whichever mode (train/eval) the caller left it in, so
    evaluating in the middle of training does not silently leave the model in
    eval mode.

    Args:
        model: Network to evaluate.
        dataloader: Sized iterable of ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        test: Unused; kept so existing callers that pass it keep working.

    Returns:
        ``(accuracy, avg_loss)``: ``accuracy`` is the percentage of correct
        predictions floored to an integer; ``avg_loss`` is the sum of the
        per-batch losses divided by the number of batches.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                outputs = model(images)
                running_loss += criterion(outputs, labels).item()
                # Predicted class = index of the largest logit in each row.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)

    accuracy = 100 * correct // total
    avg_loss = running_loss / len(dataloader)
    return accuracy, avg_loss

# Initialize the loss function
criterion = nn.CrossEntropyLoss()

# Iterate over each model, load it, and continue training
for key, model_info in models_to_train.items():
    print(f"Processing model with {key} parameters")

    # Keys are strings here while net_dict is keyed by integer parameter count
    param_count = int(key)
    model = net_dict[param_count]()
    print(model_info['path'])
    # The model is never moved off the CPU, so map the checkpoint's tensors
    # there; without this a checkpoint saved from a GPU cannot be loaded on a
    # CPU-only machine.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(model_info['path'], map_location='cpu')
    model.load_state_dict(state_dict)

    # Initialize the optimizer (fresh state; momentum buffers are not restored)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.90)

    # Continue training the model for the additional epochs; the checkpoint
    # path doubles as the base name for the retrained outputs
    train_model(model, trainloader, testloader, criterion, optimizer, model_info['additional_epochs'], model_info['path'])


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Define the directory containing the results
directory = './results/'

# Prepare lists to collect data (one entry per model, kept in lockstep).
# Despite their names, the two *_accuracies lists hold error rates in [0, 1].
param_counts = []
final_train_accuracies = []
final_test_accuracies = []
final_train_losses = []
final_test_losses = []

# List all subdirectories
subdirs = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]

# Iterate over each subdirectory to find and process the CSV files
for subdir in subdirs:
    subdir_path = os.path.join(directory, subdir)

    # Extract the number of parameters from the directory name
    try:
        num_params = int(subdir)
    except ValueError:
        # Skip directories that are not named after a number of parameters
        continue

    # Check for retrained results first
    retrained_csv_path = os.path.join(subdir_path, f"model_{subdir}_retrained_results.csv")
    original_csv_path = os.path.join(subdir_path, f"results_{subdir}.csv")

    if os.path.exists(retrained_csv_path):
        csv_path = retrained_csv_path
    elif os.path.exists(original_csv_path):
        csv_path = original_csv_path
    else:
        continue  # Skip if neither file exists

    # Read the CSV file
    df = pd.read_csv(csv_path)

    # Record the parameter count only now that a results file was found;
    # appending it earlier left param_counts longer than the metric lists
    # whenever a numeric directory had no CSV.
    param_counts.append(num_params)

    # Summarise the run: best train accuracy, mean test accuracy over the last
    # 5 epochs (both converted to error rates), and the lowest train/test loss
    # within the last 30 epochs.
    final_train_accuracies.append((100 - df['train_accuracy'].max()) / 100)
    final_test_accuracies.append((100 - df['test_accuracy'][-5:].mean()) / 100)
    final_train_losses.append(df['train_loss'][-30:].min())
    final_test_losses.append(df['test_loss'][-30:].min())

# Order every result list by ascending parameter count.
# sorted_indices[k] is the position (in the unsorted lists) of the k-th smallest model.
sorted_indices = sorted(range(len(param_counts)), key=param_counts.__getitem__)
param_counts = np.array(list(map(param_counts.__getitem__, sorted_indices)))
final_train_accuracies = list(map(final_train_accuracies.__getitem__, sorted_indices))
final_test_accuracies = list(map(final_test_accuracies.__getitem__, sorted_indices))
final_train_losses = list(map(final_train_losses.__getitem__, sorted_indices))
final_test_losses = list(map(final_test_losses.__getitem__, sorted_indices))

# Plotting Test/Train Error vs Number of Parameters without interpolation
plt.figure(figsize=(12, 6))

# Plot the actual test and train error data points
# (the final_*_accuracies lists hold error rates, i.e. (100 - accuracy) / 100)
plt.plot(param_counts, final_test_accuracies, '-', label='Test Error', color='blue', linewidth=2)
plt.plot(param_counts, final_train_accuracies, '--', label='Train Error', color='blue', alpha=0.5)

# Add the vertical line at 20,000 parameters for interpolation
plt.axvline(x=20000, color='brown', linestyle='--', linewidth=2, label='Interpolation Point at 20K')

# Highlight the critical regime between 10,608 and 33,370 parameters
critical_start = 10608
critical_end = 33370
plt.axvspan(critical_start, critical_end, color='orange', alpha=0.3, label='Critical Regime')

# Add arrow and annotation for critical regime. The anchor has to lie inside
# the y-limits set below: an annotation whose data-coordinate anchor falls
# outside the axes is not drawn.
plt.annotate('Critical Regime', xy=(critical_start, 0.12), xytext=(critical_start + 2000, 0.14),
             arrowprops=dict(facecolor='orange', shrink=0.05), fontsize=12, color='orange')

# Customizing the plot
plt.xlabel('Number of Parameters', fontsize=14)
plt.ylabel('Error', fontsize=14)
plt.ylim(-0.005, 0.16)
plt.title('Test and train Error vs Number of Parameters', fontsize=16)
# .max() raises on an empty array, so only set the ticks when results exist
if len(param_counts):
    plt.xticks(np.arange(0, param_counts.max() + 10000, step=10000))
plt.grid(True)
plt.legend(loc='best', fontsize=12)

# Show plot
plt.tight_layout()
plt.show()
