In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import random_split, DataLoader
from torch import nn, optim
from sklearn.metrics import accuracy_score, precision_score, f1_score
from tqdm import tqdm
import os
import numpy as np
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)

g = torch.Generator().manual_seed(42)

train_size = int(0.6 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size

trainset, valset, testset = random_split(dataset, [train_size, val_size, test_size],
                                         generator=g)

train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
val_loader = DataLoader(valset, batch_size=64, shuffle=False)
test_loader = DataLoader(testset, batch_size=64, shuffle=False)

In [3]:
class Perceptron(nn.Module):
    def __init__(self,
                 input_dim: int = 32 * 32 * 3,
                 hidden_dim: int = 128, output_dim: int = 10
                ):
        super().__init__()

        self.activation = nn.ReLU()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        return x

In [4]:
class ModelTrainer:
    def __init__(
        self,
        model,
        device,
        optimizer,
        criterion,
        train_loader,
        val_loader,
        test_loader,
        epochs: int = 1,
        save_weights: bool =True,
        log_dir: str = "runs"
    ):
        self.model = model
        self.device = device
        self.model = self.model.to(self.device)
        self.optimizer = optimizer
        self.criterion = criterion
        
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        
        self.epochs = epochs
        self.epoch = 1

        self.save_weights = save_weights

        self.current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.model_name = self.model.__class__.__name__
        self.log_dir = os.path.join("runs", f"{self.model_name}_PerceptronClf_{self.current_time}")

        self.writer = SummaryWriter(log_dir=self.log_dir)

        self.schedular = optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=10,
            gamma=0.5
        )
        
    def process_model(self):
        for epoch in range(self.epochs):
            train_loss, train_accuracy = self._train_model()
            val_loss, val_accuracy = self._validate_model()

            self.schedular.step()
            
            self.epoch += 1

        self.writer.close()
        return train_loss, val_loss, train_accuracy, val_accuracy

    def _train_model(self):
        self.model.train()
        total_loss = 0.0
        all_preds, all_targets = [], []

        for images, labels in tqdm(self.train_loader):
            images, labels = images.to(self.device), labels.to(self.device)

            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

        train_loss = total_loss / len(self.train_loader)
        accuracy = accuracy_score(all_targets, all_preds)
        precision = precision_score(all_targets, all_preds, average="weighted", zero_division=0)
        f1 = f1_score(all_targets, all_preds, average="weighted", zero_division=0)
        
        self._print_metrics("Train", train_loss, accuracy, precision, f1)
        
        self.writer.add_scalar("Loss/Train", train_loss, self.epoch)
        self.writer.add_scalar("Accuracy/Train", accuracy, self.epoch)
        self.writer.add_scalar("Precision/Train", precision, self.epoch)
        self.writer.add_scalar("F1/Train", f1, self.epoch)

        return train_loss, accuracy

    def _validate_model(self):
        self.model.eval()
        total_loss = 0.0
        all_preds, all_targets = [], []

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader):
                images, labels = images.to(self.device), labels.to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                
                total_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(labels.cpu().numpy())
            
        val_loss = total_loss / len(self.val_loader)
        accuracy = accuracy_score(all_targets, all_preds)
        precision = precision_score(all_targets, all_preds, average="weighted", zero_division=0)
        f1 = f1_score(all_targets, all_preds, average="weighted", zero_division=0)
        
        self._print_metrics("Val", val_loss, accuracy, precision, f1)
        
        self.writer.add_scalar("Loss/Val", val_loss, self.epoch)
        self.writer.add_scalar("Accuracy/Val", accuracy, self.epoch)
        self.writer.add_scalar("Precision/Val", precision, self.epoch)
        self.writer.add_scalar("F1/Val", f1, self.epoch)

        return val_loss, accuracy

    def test_model(self):
        self.model.eval()
        total_loss = 0.0
        all_preds, all_targets = [], []

        with torch.no_grad():
            for images, labels in tqdm(self.test_loader):
                images, labels = images.to(self.device), labels.to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                
                total_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(labels.cpu().numpy())
                
        test_loss = total_loss / len(self.test_loader)
        accuracy = accuracy_score(all_targets, all_preds)
        precision = precision_score(all_targets, all_preds, average="weighted", zero_division=0)
        f1 = f1_score(all_targets, all_preds, average="weighted", zero_division=0)
        
        self._print_metrics("Test", test_loss, accuracy, precision, f1)

        if self.save_weights:
            torch.save(self.model.state_dict(), 'densora_gru_weights.pth')

        return test_loss, accuracy

    def _print_metrics(self, phase: str, loss, accuracy, precision, f1):
        if phase == "Test":
            print(f"{phase} Loss {self.epoch - 1}: {loss:.4f} | Accuracy: {accuracy:.4f}")
        else:
            print(f"{phase} Loss {self.epoch}: {loss:.4f} | Accuracy: {accuracy:.4f}")

        print(f"Precision: {precision:.4f}")
        print(f"F1 score: {f1:.4f}")

In [6]:
model = Perceptron()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainer = ModelTrainer(
    model=model,
    device=device,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    epochs=20,   
    save_weights=False
)

train_loss, val_loss, train_acc, val_acc = trainer.process_model()
print(f"Final Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
print(f"Final Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

test_loss, test_acc = trainer.test_model()
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

100%|███████████████████████████████████████████████████████████████████████████| 469/469 [00:07<00:00, 64.85it/s]


Train Loss 1: 2.2258 | Accuracy: 0.2043
Precision: 0.1293
F1 score: 0.1374


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 102.11it/s]


Val Loss 1: 2.1584 | Accuracy: 0.2541
Precision: 0.1751
F1 score: 0.1802


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 107.88it/s]


Train Loss 2: 2.1258 | Accuracy: 0.2707
Precision: 0.1799
F1 score: 0.2019


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 127.52it/s]


Val Loss 2: 2.0899 | Accuracy: 0.2861
Precision: 0.2207
F1 score: 0.2284


100%|███████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 97.60it/s]


Train Loss 3: 2.0501 | Accuracy: 0.3039
Precision: 0.2368
F1 score: 0.2559


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 116.62it/s]


Val Loss 3: 2.0206 | Accuracy: 0.3132
Precision: 0.2628
F1 score: 0.2651


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 101.76it/s]


Train Loss 4: 1.9641 | Accuracy: 0.3293
Precision: 0.2881
F1 score: 0.3013


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 119.57it/s]


Val Loss 4: 1.9202 | Accuracy: 0.3421
Precision: 0.3136
F1 score: 0.3161


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 102.90it/s]


Train Loss 5: 1.9062 | Accuracy: 0.3454
Precision: 0.3061
F1 score: 0.3203


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 121.45it/s]


Val Loss 5: 1.8942 | Accuracy: 0.3400
Precision: 0.3217
F1 score: 0.3051


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 105.09it/s]


Train Loss 6: 1.8346 | Accuracy: 0.3647
Precision: 0.3598
F1 score: 0.3554


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 125.95it/s]


Val Loss 6: 1.7956 | Accuracy: 0.3680
Precision: 0.3874
F1 score: 0.3623


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 104.99it/s]


Train Loss 7: 1.7627 | Accuracy: 0.3864
Precision: 0.3813
F1 score: 0.3819


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 118.32it/s]


Val Loss 7: 1.7637 | Accuracy: 0.3816
Precision: 0.4043
F1 score: 0.3740


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 104.56it/s]


Train Loss 8: 1.7338 | Accuracy: 0.3969
Precision: 0.3916
F1 score: 0.3921


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 120.08it/s]


Val Loss 8: 1.7233 | Accuracy: 0.3992
Precision: 0.4021
F1 score: 0.3948


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 105.48it/s]


Train Loss 9: 1.7079 | Accuracy: 0.4047
Precision: 0.3982
F1 score: 0.3994


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 120.85it/s]


Val Loss 9: 1.7074 | Accuracy: 0.4014
Precision: 0.4062
F1 score: 0.3924


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 104.07it/s]


Train Loss 10: 1.6833 | Accuracy: 0.4122
Precision: 0.4058
F1 score: 0.4068


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.39it/s]


Val Loss 10: 1.7177 | Accuracy: 0.3918
Precision: 0.4132
F1 score: 0.3814


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 105.94it/s]


Train Loss 11: 1.6513 | Accuracy: 0.4259
Precision: 0.4193
F1 score: 0.4205


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 115.10it/s]


Val Loss 11: 1.6727 | Accuracy: 0.4161
Precision: 0.4185
F1 score: 0.4101


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 103.34it/s]


Train Loss 12: 1.6402 | Accuracy: 0.4317
Precision: 0.4251
F1 score: 0.4262


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 118.00it/s]


Val Loss 12: 1.6631 | Accuracy: 0.4144
Precision: 0.4172
F1 score: 0.4094


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 101.36it/s]


Train Loss 13: 1.6307 | Accuracy: 0.4318
Precision: 0.4248
F1 score: 0.4262


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 116.59it/s]


Val Loss 13: 1.6522 | Accuracy: 0.4238
Precision: 0.4214
F1 score: 0.4197


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 104.15it/s]


Train Loss 14: 1.6203 | Accuracy: 0.4350
Precision: 0.4283
F1 score: 0.4294


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 116.78it/s]


Val Loss 14: 1.6550 | Accuracy: 0.4243
Precision: 0.4360
F1 score: 0.4184


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 104.37it/s]


Train Loss 15: 1.6102 | Accuracy: 0.4417
Precision: 0.4351
F1 score: 0.4363


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 113.56it/s]


Val Loss 15: 1.6479 | Accuracy: 0.4202
Precision: 0.4354
F1 score: 0.4198


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 103.37it/s]


Train Loss 16: 1.6024 | Accuracy: 0.4439
Precision: 0.4367
F1 score: 0.4381


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 123.85it/s]


Val Loss 16: 1.6287 | Accuracy: 0.4302
Precision: 0.4272
F1 score: 0.4251


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 105.22it/s]


Train Loss 17: 1.5934 | Accuracy: 0.4460
Precision: 0.4392
F1 score: 0.4403


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 123.87it/s]


Val Loss 17: 1.6292 | Accuracy: 0.4294
Precision: 0.4363
F1 score: 0.4264


100%|██████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 102.44it/s]


Train Loss 18: 1.5830 | Accuracy: 0.4536
Precision: 0.4470
F1 score: 0.4482


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 123.28it/s]


Val Loss 18: 1.6279 | Accuracy: 0.4261
Precision: 0.4349
F1 score: 0.4171


100%|███████████████████████████████████████████████████████████████████████████| 469/469 [00:05<00:00, 92.76it/s]


Train Loss 19: 1.5747 | Accuracy: 0.4543
Precision: 0.4476
F1 score: 0.4488


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 113.77it/s]


Val Loss 19: 1.6211 | Accuracy: 0.4311
Precision: 0.4414
F1 score: 0.4272


100%|███████████████████████████████████████████████████████████████████████████| 469/469 [00:04<00:00, 95.70it/s]


Train Loss 20: 1.5675 | Accuracy: 0.4566
Precision: 0.4496
F1 score: 0.4510


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 103.97it/s]


Val Loss 20: 1.6041 | Accuracy: 0.4379
Precision: 0.4340
F1 score: 0.4289
Final Train Loss: 1.5675, Val Loss: 1.6041
Final Train Acc: 0.4566, Val Acc: 0.4379


100%|██████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 117.56it/s]

Test Loss 20: 1.5961 | Accuracy: 0.4425
Precision: 0.4359
F1 score: 0.4324
Test Loss: 1.5961, Test Accuracy: 0.4425



