In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torchmetrics

from torchvision import datasets

In [None]:
class DatasetCIFAR(Dataset):
    """Map-style dataset over in-memory image and label collections.

    Args:
        x_data: Indexable, sized collection of images. ``x_data[index]`` is
            passed through ``transform`` (when one is given) and returned as
            the image of the sample.
        y_data: Indexable collection of labels aligned with ``x_data``.
            ``y_data[index]`` may be a NumPy array (e.g. a row of an ``(N, 1)``
            label matrix) or a scalar (e.g. an element of an ``(N,)`` vector).
        transform: Optional callable applied to every image on access.
    """

    def __init__(self, x_data, y_data, transform=None):
        self.x_data = x_data
        self.y_data = y_data
        self.transform = transform

    def __getitem__(self, index):
        """Load and return a sample from the dataset at the given index.

        Returns:
            tuple: ``(img, label)`` where ``img`` is the (optionally
            transformed) image and ``label`` is a ``torch.Tensor``.
        """
        img = self.x_data[index]

        # augmentations
        if self.transform is not None:
            img = self.transform(img)

        # ``torch.from_numpy`` rejects NumPy/Python scalars, which is what a
        # 1-D label vector yields on indexing. ``as_tensor`` handles ndarray
        # rows the same way and additionally accepts scalars.
        label = torch.as_tensor(self.y_data[index])

        return img, label

    def __len__(self):
        """Return the number of samples in dataset."""
        return len(self.x_data)

In [None]:
class DatamoduleCIFAR():
    """Create dataset and loaders, apply transforms.

    On construction both CIFAR-10 splits are downloaded (if necessary) and
    materialised as NumPy arrays:

    * ``x_train`` / ``x_test``: images divided by 255, stacked to (N, H, W, C).
    * ``y_train`` / ``y_test``: integer labels as a column of shape (N, 1).

    Args:
        root: Directory the CIFAR-10 files are stored in / downloaded to.
    """

    def __init__(self, root='./data'):
        # load data
        trainset = datasets.CIFAR10(root=root, train=True, download=True)
        valset = datasets.CIFAR10(root=root, train=False, download=True)

        self.x_train, self.y_train = self._to_arrays(trainset)
        self.x_test, self.y_test = self._to_arrays(valset)

    @staticmethod
    def _to_arrays(dataset):
        """Turn an iterable of ``(image, label)`` pairs into ``(x, y)`` arrays.

        The dataset is walked exactly once; collecting images and labels in
        separate passes would build every sample twice.
        """
        images, labels = [], []
        for img, label in dataset:
            images.append(np.array(img) / 255.)
            labels.append(label)

        x_data = np.stack(images)                    # (N, H, W, C)
        y_data = np.array(labels, dtype=int)[:, None]  # (N, 1)
        return x_data, y_data

    def create_loaders(self, batch_size=100):
        """Create loaders both for train and test/validation datasets.

        Args:
            batch_size: Number of samples per batch for both loaders.

        Returns:
            tuple: ``(train_loader, test_loader)``. Only the training loader
            shuffles its samples.
        """
        # The same image transform is applied to both splits.
        dset_train = DatasetCIFAR(self.x_train, self.y_train, transform=transforms.ToTensor())
        dset_test = DatasetCIFAR(self.x_test, self.y_test, transform=transforms.ToTensor())

        train_loader = DataLoader(dset_train, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(dset_test, batch_size=batch_size, shuffle=False)

        return train_loader, test_loader

In [None]:
class ModelCIFAR(nn.Module):
    """Small CNN classifier with its own loss, optimizer, metrics and train loop.

    The network is three ``Conv2d -> ReLU -> MaxPool2d(2)`` stages followed by
    a linear layer producing 10 logits. The linear layer takes ``4 * 4 * 16``
    features, i.e. 3-channel inputs whose spatial size is 4x4 after three
    halvings (32x32 images).
    """

    def __init__(self):
        super().__init__()

        # CNN
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Flatten(),
            nn.Linear(in_features=4 * 4 * 16, out_features=10)
        )

        self.loss_ce = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(self.cnn.parameters(), lr=1e-3)

        # Metrics
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        # NOTE(review): no ``average`` is passed here; if the wrapper default is
        # micro-averaging this value will coincide with accuracy -- confirm
        # whether ``average='macro'`` was intended.
        self.prec = torchmetrics.Precision(task='multiclass', num_classes=10)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.cnn(x)

    def _run_epoch(self, loader, train):
        """Run one pass over ``loader`` and return ``(mean_loss, accuracy, precision)``.

        Args:
            loader: Iterable of ``(images, labels)`` batches; ``labels`` must
                be 2-D, the class index is read from column 0.
            train: When true the network is put in training mode and the
                optimizer is stepped after every batch; otherwise the network
                is put in evaluation mode and no weights are updated.
        """
        # Batches are moved to wherever the network weights live, so the model
        # can be relocated (e.g. with ``.cuda()``) without touching the loop.
        device = next(self.cnn.parameters()).device

        # The mode is constant for the whole pass; set it once.
        self.cnn.train(train)

        batch_losses = []
        batch_preds = []
        batch_targets = []
        for images, labels in loader:
            images = images.float().to(device)
            targets = labels[:, 0].to(device)

            # make prediction
            logits_cls = self.cnn(images)

            # calculate loss
            loss = self.loss_ce(logits_cls, targets)

            if train:
                # update weights
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            batch_losses.append(loss.item())

            # Softmax is order-preserving, so the arg-max of the raw logits is
            # already the predicted class.
            batch_preds.append(torch.argmax(logits_cls, dim=1))
            batch_targets.append(targets)

        # find metrics in the end of the epoch
        predictions = torch.cat(batch_preds)
        targets_all = torch.cat(batch_targets)

        accuracy = self.accuracy(predictions, targets_all)
        precision = self.prec(predictions, targets_all)

        return np.mean(batch_losses), accuracy, precision

    def fit(self, train_loader, test_loader, num_epoch=50):
        """Train the network and report train/test metrics after every epoch.

        Args:
            train_loader: Iterable of ``(images, labels)`` training batches.
            test_loader: Iterable of ``(images, labels)`` evaluation batches.
            num_epoch: Number of passes over ``train_loader``.

        Returns:
            None. Loss, accuracy and precision are printed for each epoch.
        """
        for ii in range(num_epoch):
            # train
            loss_train, acc_train, prec_train = self._run_epoch(train_loader, train=True)

            print(f"Epoch: {ii}")
            print(f"TRAIN | Loss: {loss_train: .3f}, Train_acc: {acc_train: .3f}, Train_prec: {prec_train: .3f}")

            # test
            with torch.no_grad():
                loss_test, acc_test, prec_test = self._run_epoch(test_loader, train=False)

            print(f"TEST | Loss: {loss_test: .3f}, Test_acc: {acc_test: .3f}, Test_prec: {prec_test: .3f}")


In [None]:
# Seed PyTorch's global RNG before anything random happens, so the initial
# weights and the shuffle order of the training loader are the same on every
# "Restart & Run All".
torch.manual_seed(0)

cnn_model = ModelCIFAR()

train_loader, test_loader = DatamoduleCIFAR().create_loaders()

cnn_model.fit(train_loader, test_loader, num_epoch=10)

In [None]:
# Number of trainable parameters
sum(
    tensor.numel()
    for tensor in filter(lambda p: p.requires_grad, cnn_model.parameters())
)