In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch
from torchvision.transforms.functional import to_tensor, to_pil_image

In [None]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


def unpickle(file):
    """Read a pickle file and return the object stored in it.

    Args:
        file: Path of the pickle file to open (opened in binary mode).

    Returns:
        The unpickled object. The file is decoded with ``encoding='bytes'``,
        which is why callers in this notebook index the result with bytes
        keys such as ``b'data'`` and ``b'labels'``.
    """
    import pickle

    # SECURITY: unpickling can execute arbitrary code contained in the file;
    # only use this on files from a trusted source.
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')


"""
1.  Define and build a PyTorch Dataset
"""
class CIFAR10(Dataset):
    """Map-style dataset backed by one or more pickled batch files.

    Each file in ``data_files`` is read with ``unpickle`` and must yield a
    mapping with a ``b'data'`` entry (one flattened image per row, 3072
    values each, as required by the reshape in ``__getitem__``) and a
    ``b'labels'`` entry (one label per row). All files are loaded into memory
    when the dataset is constructed.

    Args:
        data_files: Sequence of batch file paths, loaded in the given order.
        transform: Optional callable applied to each image before it is
            returned.
        target_transform: Optional callable applied to each label before it
            is returned.
    """

    def __init__(self, data_files, transform=None, target_transform=None):
        data_chunks = []
        self.image_labels = []
        for data_file in data_files:
            data_dict = unpickle(data_file)
            data_chunks.append(data_dict[b'data'])
            self.image_labels.extend(data_dict[b'labels'])

        # Stack once after the loop: stacking inside the loop re-copies every
        # previously loaded batch on each iteration.
        if len(data_chunks) > 1:
            self.image_data = np.vstack(data_chunks)
        elif data_chunks:
            self.image_data = data_chunks[0]
        else:
            # No files given: keep an empty container so len() is 0.
            self.image_data = []

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of images loaded."""
        return len(self.image_data)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The stored row is reshaped to ``(3, 32, 32)`` and its leading axis is
        moved last, giving a ``(32, 32, 3)`` array; this is the array that
        ``transform`` receives. ``transform`` / ``target_transform`` are
        applied only when they are truthy.
        """
        image = self.image_data[idx].reshape(3, 32, 32).transpose(1, 2, 0)
        label = self.image_labels[idx]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return (image, label)
    

def get_preprocess_transform(mode):
    """Build the preprocessing pipeline applied to each image.

    Args:
        mode: Currently unused; the same pipeline is returned for any value.

    Returns:
        A ``transforms.Compose`` that applies ``ToTensor`` followed by
        ``Normalize`` with the per-channel mean/std listed below.
    """
    to_tensor_step = transforms.ToTensor()
    # NOTE(review): these values look like the commonly used ImageNet
    # statistics rather than CIFAR-10 ones - confirm this is intended.
    normalize_step = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return transforms.Compose([to_tensor_step, normalize_step])


def build_dataset(data_files, transform=None):
    """Create a ``CIFAR10`` dataset from the given batch files.

    Args:
        data_files: Batch file paths forwarded to ``CIFAR10``.
        transform: Optional image transform forwarded to ``CIFAR10``.

    Returns:
        The constructed ``CIFAR10`` instance (no target transform is set).
    """
    return CIFAR10(data_files, transform=transform)



"""
2.  Build a PyTorch DataLoader
"""
def build_dataloader(dataset, loader_params):
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: Dataset to iterate over.
        loader_params: Mapping that must contain ``"batch_size"`` and
            ``"shuffle"``; both are passed straight to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    batch_size = loader_params["batch_size"]
    shuffle = loader_params["shuffle"]
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


"""
3. (a) Build a neural network class.
"""
class FinetuneNet(torch.nn.Module):
    """Small CNN: two conv -> ReLU -> max-pool stages, then two linear layers.

    The flatten step assumes 32 feature maps of size 8 x 8, which is what
    3-channel 32 x 32 inputs produce after the two 2 x 2 poolings. The last
    layer emits 10 values per sample.
    """

    def __init__(self):
        super().__init__()

        # 3x3 convolutions with padding=1 keep height/width; each pooling
        # step (kernel 2, stride 2) halves them.
        self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)

        # Classifier head on the flattened feature maps.
        self.hidden1 = torch.nn.Linear(in_features=32 * 8 * 8, out_features=250)
        self.hidden2 = torch.nn.Linear(in_features=250, out_features=10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``hidden2``."""
        # The same pooling module is reused for both stages.
        x = self.pool1(self.relu(self.conv1(x)))
        x = self.pool1(self.relu(self.conv2(x)))

        # Flatten to one row of 32 * 8 * 8 features per sample.
        x = x.view(-1, 32 * 8 * 8)

        x = self.relu(self.hidden1(x))
        return self.hidden2(x)


"""
3. (b)  Build a model
"""
def build_model(trained=False):
    """Return a newly constructed ``FinetuneNet``.

    Args:
        trained: Currently unused; no weights are loaded for any value.
    """
    return FinetuneNet()


"""
4.  Build a PyTorch optimizer
"""
def build_optimizer(optim_type, model_params, hparams):
    """Create a ``torch.optim`` optimizer for the given parameters.

    Args:
        optim_type: Optimizer name, either ``"SGD"`` or ``"Adam"``.
        model_params: Parameters to optimize (e.g. ``model.parameters()``).
        hparams: Either a single number used as the learning rate, or a dict
            of keyword arguments (e.g. ``{"lr": 0.01, "momentum": 0.9}``)
            passed through to the optimizer constructor.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: If ``optim_type`` is not a supported optimizer name.
    """
    optimizer_classes = {
        "SGD": torch.optim.SGD,
        "Adam": torch.optim.Adam,
    }
    optimizer_cls = optimizer_classes.get(optim_type)
    if optimizer_cls is None:
        # Previously an unknown name fell through both `if`s and failed with
        # an UnboundLocalError at the return statement.
        supported = ", ".join(sorted(optimizer_classes))
        raise ValueError(
            f"Unsupported optim_type {optim_type!r}; expected one of: {supported}"
        )

    if isinstance(hparams, dict):
        return optimizer_cls(params=model_params, **hparams)
    return optimizer_cls(params=model_params, lr=hparams)


"""
5. Training loop for model
"""
def train(train_dataloader, model, loss_fn, optimizer, num_epochs=38, plot=True):
    """Train ``model`` and report per-epoch loss and accuracy.

    For every batch the model output is scored with ``loss_fn``, gradients
    are back-propagated and ``optimizer`` takes a step. After each epoch the
    mean batch loss and the fraction of correct predictions are printed; when
    ``plot`` is true, both histories are plotted once training finishes.

    Args:
        train_dataloader: Iterable yielding ``(X, y)`` batches. ``y`` is
            compared against the arg-max of the model output along
            dimension 1, so it is assumed to hold class indices.
        model: Module to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of passes over ``train_dataloader`` (default 38,
            the value that used to be hard-coded).
        plot: When true (default), show the loss and accuracy curves with
            matplotlib. Pass ``False`` for headless runs.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches in an epoch.
    """
    loss_history = []
    accuracy_history = []

    # Make sure mode-dependent layers (dropout, batch norm, ...) behave as in
    # training, whatever mode the caller left the model in.
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        num_batches = 0

        for X, y in train_dataloader:
            # Compute prediction and loss
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted_labels = torch.max(pred, 1)
            correct_predictions += (predicted_labels == y).sum().item()
            total_predictions += y.size(0)
            num_batches += 1

        if num_batches == 0:
            # Averaging over zero batches would otherwise surface as an
            # opaque ZeroDivisionError.
            raise ValueError("train_dataloader yielded no batches; nothing to train on")

        # Counting batches (rather than calling len()) also works for loaders
        # that do not define a length.
        epoch_loss = total_loss / num_batches
        epoch_accuracy = correct_predictions / total_predictions
        loss_history.append(epoch_loss)
        accuracy_history.append(epoch_accuracy)

        print(f'Epoch {epoch+1}, Loss: {epoch_loss}, Accuracy: {epoch_accuracy}')

    if plot:
        _plot_training_curve(loss_history, label='Loss', title='Training Loss', ylabel='Loss')
        _plot_training_curve(accuracy_history, label='Accuracy', title='Training Accuracy',
                             ylabel='Accuracy', color='orange')


def _plot_training_curve(values, label, title, ylabel, color=None):
    """Plot one per-epoch metric in its own figure and show it."""
    plt.figure(figsize=(10, 5))
    if color is None:
        plt.plot(values, label=label)
    else:
        plt.plot(values, label=label, color=color)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()


'''Run model'''
def run_model(data_dir="cifar-10-batches-py"):
    """Build, train and return a model on the five training batch files.

    Args:
        data_dir: Directory containing ``data_batch_1`` ... ``data_batch_5``.
            Defaults to the previously hard-coded ``cifar-10-batches-py``.

    Returns:
        The trained model.
    """
    new_model = build_model()

    train_files = [f"{data_dir}/data_batch_{index}" for index in range(1, 6)]
    # Pass the mode as the string "train"; the original passed the `train`
    # function object by mistake.
    train_dataset = build_dataset(train_files, transform=get_preprocess_transform("train"))

    train_params = {"batch_size": 64, "shuffle": True}
    train_dataloader = build_dataloader(train_dataset, train_params)

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = build_optimizer("Adam", new_model.parameters(), hparams=0.01)
    train(train_dataloader, new_model, loss_fn, optimizer)

    return new_model

In [None]:
from torchvision import transforms

# Building the dataset may take a little while: the whole batch file is
# unpickled into memory when the dataset is constructed.
example_dataset = build_dataset(["cifar-10-batches-py/data_batch_1"], transform=transforms.ToTensor())

In [None]:
# Train a fresh model on the five CIFAR-10 training batches; run_model()
# returns the trained network.
run_model()