In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torchvision.utils as utils
import matplotlib.pyplot as plt
from torchmetrics import Accuracy
from torchinfo import summary
from pathlib import Path
from PIL import Image
import pickle

In [2]:
if torch.cuda.is_available(): 
    dev = "cuda:0"
else: 
    dev = "cpu"

In [3]:
def load_dataset(file):
    with open(file, 'rb') as fo:
        file_output = pickle.load(fo, encoding='bytes')
    return file_output

In [4]:
meta_data = load_dataset('/home/jhermosilla/Proyects/Datasets/cifar-100-python/meta')
superclasses_set = meta_data[b'coarse_label_names']

In [5]:
train_data = load_dataset('/home/jhermosilla/Proyects/Datasets/cifar-100-python/train')
train_images = train_data[b'data']
train_images = train_images.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
train_labels = train_data[b'coarse_labels']

In [6]:
test_data = load_dataset('/home/jhermosilla/Proyects/Datasets/cifar-100-python/test')
test_images = test_data[b'data']
test_images = test_images.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
test_labels = test_data[b'coarse_labels']

In [7]:
class Data2Tuple(torch.utils.data.Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        image = Image.fromarray(image)  
        if self.transform:
            image = self.transform(image)
        return image, label

In [9]:
class Learning_class():
    def __init__(self, model):
        self.epochs = 30
        self.device = torch.device(dev)
        self.model = model.to(self.device)
        self.loss_func = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task='multiclass', num_classes=20)
        self.accuracy = self.accuracy.to(dev)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        self.model_path = Path("models")
        self.model_name = "lenet5_cifar.pth"
        self.model_path.mkdir(parents=True, exist_ok=True)
        self.full_path = self.model_path / self.model_name
        self.train_loss_hist = []
        self.train_acc_hist = []
        self.test_loss_hist = []
        self.test_acc_hist = []
        self.history = []

    def train(self, history = False):
        for epoch in range(self.epochs):
            train_loss, train_acc = 0.0, 0.0
            for batch_idx, (images, labels) in enumerate(self.model.train_loader):
                images = images.to(self.device)
                labels = labels.to(self.device)
                self.model.train()
                outputs = self.model(images)
                loss = self.loss_func(outputs, labels)
                with torch.no_grad():
                    train_loss += loss.item()
                acc = self.accuracy(outputs, labels)
                train_acc += acc
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            train_loss /= len(self.model.train_loader)
            train_acc /= len(self.model.train_loader)

            test_loss, test_acc = 0.0, 0.0
            self.model.eval()
            with torch.inference_mode():
                for batch_idx, (images, labels) in enumerate(self.model.test_loader):
                    images = images.to(self.device)
                    labels = labels.to(self.device)
                    outputs = self.model(images)
                    loss = self.loss_func(outputs,labels)
                    acc = self.accuracy(outputs,labels)
                    with torch.no_grad():
                        test_loss += loss.item()
                        test_acc += acc
                test_loss /= len(self.model.test_loader)
                test_acc /= len(self.model.test_loader)

            print(f"Epoch: {epoch+1} Train loss: {train_loss: .5f} Train acc: {train_acc: .5f} Test loss: {test_loss: .5f} Test acc: {test_acc: .5f}")
            if (history):
                self.train_loss_hist.append(train_loss)
                self.test_loss_hist.append(test_loss)
                self.train_acc_hist.append(train_acc.tolist())
                self.test_acc_hist.append(test_acc.tolist())

        self.history.append(self.train_loss_hist)
        self.history.append(self.test_loss_hist)
        self.history.append(self.train_acc_hist)
        self.history.append(self.test_acc_hist)
        return self.model, self.history

    def plot_loss(self, history):
        plt.figure(figsize=(5, 5))
        plt.plot(range(1,self.epochs+1),history[0], label='Train', color='red')
        plt.plot(range(1,self.epochs+1),history[1], label='Test', color='green')
        plt.title('Loss history')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

    def plot_accuracy(self, history):
        plt.figure(figsize=(5, 5))
        plt.plot(range(1,self.epochs+1),history[2], label='Train', color='red')
        plt.plot(range(1,self.epochs+1),history[3], label='Test', color='green')
        plt.title('Accuracy history')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()

    def save_model(self, trained_model):
        print("====================================================================================================")
        print(f"Saving the model: {self.full_path}")
        torch.save(obj=trained_model.state_dict(), f=self.full_path)

    def load_model(self, trained_model):
        trained_model.load_state_dict(torch.load(self.full_path, weights_only=True))
        return trained_model

In [10]:
model = LeNet5()
learning_model = Learning_class(model)

In [11]:
trained_model, history = learning_model.train(history=True)

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>