In [None]:
%pip install torchmetrics

In [None]:
import sys
import random
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torchvision
from datasets import load_dataset, ClassLabel
from torchvision import transforms
from tqdm import tqdm
from torchmetrics.classification import MulticlassF1Score

########################################
# 1. Fix random seeds
########################################
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer passed unchanged to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed``, ``torch.cuda.manual_seed`` and
            ``torch.cuda.manual_seed_all``. Defaults to 42.

    Side effects:
        Sets ``torch.backends.cudnn.deterministic`` to ``True`` and
        ``torch.backends.cudnn.benchmark`` to ``False``, then prints a
        one-line confirmation containing the seed.
    """
    # Call order: Python, NumPy, torch (CPU), current CUDA device, all CUDA
    # devices.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f"[ OK ] global seed = {seed}")

# Single seed value for the run: applied to the global RNGs right here and
# reused further down when the DataLoader generator is seeded.
seed = 42
seed_everything(seed)

# For DataLoader reproducibility
def seed_worker(worker_id):
    """Seed NumPy and Python's ``random`` inside a DataLoader worker.

    Meant to be passed as ``worker_init_fn``. The seed is derived from
    ``torch.initial_seed()`` rather than from a notebook-level constant.
    Per the PyTorch reproducibility notes, inside a worker that value is
    already distinct per worker and is re-drawn from the DataLoader's
    ``generator`` each time workers are started. This keeps the NumPy /
    ``random`` streams reproducible while preventing every epoch from
    replaying the identical stream, and removes the dependency on a global
    ``seed`` variable.

    Args:
        worker_id: Worker index supplied by the DataLoader. Unused, since
            ``torch.initial_seed()`` already differs between workers; kept
            because ``worker_init_fn`` is called with it.
    """
    # NumPy only accepts 32-bit seeds, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Dedicated torch generator, seeded with the notebook seed, handed to the
# DataLoaders below.
g = torch.Generator().manual_seed(seed)

########################################
# Dataset + transforms
########################################
# Load the 'Maysee/tiny-imagenet' dataset and cast its "label" column to a
# 200-class ClassLabel feature in one chained expression.
tiny_imagenet = (
    load_dataset('Maysee/tiny-imagenet')
    .cast_column("label", ClassLabel(num_classes=200))
)


def train_transform_function(examples):
    """Apply the training augmentation pipeline to every image in a batch.

    Pipeline, in order: random horizontal flip (p=0.5), mild brightness and
    contrast jitter, random resized crop to 64 px covering 80-95% of the
    area, resize to 224x224, conversion to a tensor, and per-channel
    normalisation with the listed mean/std.

    Args:
        examples: Batch mapping whose ``'image'`` entry is an iterable of
            images exposing ``convert("RGB")``.

    Returns:
        The same ``examples`` object, with ``'image'`` replaced by the list
        of transformed images.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    augment = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=(0.9, 1.08), contrast=(0.9, 1.08)),
        transforms.RandomResizedCrop(64, scale=(0.8, 0.95)),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    augmented = []
    for picture in examples['image']:
        # Force three channels so the 3-value mean/std always matches.
        augmented.append(augment(picture.convert("RGB")))
    examples['image'] = augmented
    return examples


def val_transform_function(examples):
    """Apply the deterministic evaluation preprocessing to a batch.

    Pipeline, in order: resize to 224x224, conversion to a tensor, and
    per-channel normalisation with the listed mean/std. No random
    augmentation is involved.

    Args:
        examples: Batch mapping whose ``'image'`` entry is an iterable of
            images exposing ``convert("RGB")``.

    Returns:
        The same ``examples`` object, with ``'image'`` replaced by the list
        of transformed images.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    processed = []
    for picture in examples['image']:
        # Force three channels so the 3-value mean/std always matches.
        processed.append(preprocess(picture.convert("RGB")))
    examples['image'] = processed
    return examples


# Bind each split to its transform function via `with_transform`: the 'train'
# split gets the augmenting pipeline, the 'valid' split the deterministic one.
train_tiny_imagenet = tiny_imagenet['train'].with_transform(train_transform_function)
val_tiny_imagenet = tiny_imagenet['valid'].with_transform(val_transform_function)


########################################
# DataLoaders with deterministic workers
########################################
# Keyword arguments shared by the training and validation loaders; only the
# dataset and the shuffle flag differ between the two.
_loader_settings = dict(
    batch_size=128,
    num_workers=28,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g,
)

# Training batches are reshuffled on every pass.
train_loader = torch.utils.data.DataLoader(
    train_tiny_imagenet,
    shuffle=True,
    **_loader_settings,
)

# Validation batches keep the dataset order.
val_loader = torch.utils.data.DataLoader(
    val_tiny_imagenet,
    shuffle=False,
    **_loader_settings,
)

########################################
# Device
########################################
# Device string used for the model, the metric and every batch: "cuda" when
# PyTorch reports CUDA as available, "cpu" otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def _build_model(model_name):
    """Create an untrained AlexNet or ResNet-18 whose last layer has 200 outputs."""
    if model_name == 'alexnet':
        model = torchvision.models.alexnet()  # weights='DEFAULT'
        model.classifier[6] = nn.Linear(4096, 200)
    else:
        # Any name other than 'alexnet' selects ResNet-18.
        model = torchvision.models.resnet18()  # weights='DEFAULT'
        model.fc = nn.Linear(512, 200)
    return model


def _train_one_epoch(model, criterion, optimizer):
    """Run one optimisation pass over ``train_loader``; return (mean loss, accuracy)."""
    model.train()

    loss_sum = 0.0
    n_correct = 0
    for batch in tqdm(train_loader):
        data = batch['image'].to(DEVICE)
        target = batch['label'].to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # The criterion returns a batch mean; weight it by the batch size so
        # the final division yields a per-sample mean.
        loss_sum += loss.item() * data.size(0)
        n_correct += (output.argmax(1) == target).sum().item()

    n_samples = len(train_loader.dataset)
    return loss_sum / n_samples, n_correct / n_samples


def _evaluate(model, f1_metric):
    """Score the model on ``val_loader``; return (top-1 acc, top-5 acc, macro F1)."""
    model.eval()
    f1_metric.reset()

    top1_hits = 0
    top5_hits = 0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            data = batch['image'].to(DEVICE)
            target = batch['label'].to(DEVICE)
            output = model(data)
            top1_hits += (output.argmax(1) == target).sum().item()
            _, pred_top5 = output.topk(5, dim=1)
            # A sample is a top-5 hit when its label equals any of its five
            # predicted class indices; computed for the whole batch at once.
            top5_hits += (pred_top5 == target.unsqueeze(1)).any(dim=1).sum().item()
            f1_metric.update(output, target)

    n_samples = len(val_loader.dataset)
    return top1_hits / n_samples, top5_hits / n_samples, f1_metric.compute().item()


def testing(model_name='alexnet',
            l_rate=1e-4,
            w_decay=5e-5,
            log_file='training_log.txt',
            n_epochs=80):
    """Train a model from scratch and append per-epoch metrics to a log file.

    Uses the module-level ``train_loader``, ``val_loader`` and ``DEVICE``.
    Optimisation is SGD with momentum 0.9 and cross-entropy loss. After each
    epoch one line with train loss, train accuracy, validation top-1/top-5
    accuracy and macro F1 is written.

    Log lines are written directly to the file and flushed immediately;
    ``sys.stdout`` is never replaced, so an exception during training cannot
    leave the notebook printing into a closed file.

    Args:
        model_name: ``'alexnet'`` for AlexNet; any other value selects
            ResNet-18.
        l_rate: Learning rate for SGD.
        w_decay: Weight decay for SGD.
        log_file: Path of the text log, opened in append mode.
        n_epochs: Number of training epochs. Defaults to 80.
    """
    with open(log_file, 'a') as log:
        print(f"\n\n\nlr = {l_rate}, weight_decay = {w_decay}", file=log, flush=True)

        model = _build_model(model_name).to(DEVICE)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=l_rate, momentum=0.9, weight_decay=w_decay)
        f1_metric = MulticlassF1Score(num_classes=200, average="macro").to(DEVICE)

        for epoch in range(n_epochs):
            train_loss, train_acc = _train_one_epoch(model, criterion, optimizer)
            val_acc, val_acc_top5, val_f1 = _evaluate(model, f1_metric)

            print(
                f"Epoch: {epoch+1} "
                f"Train Loss: {train_loss:.6f} "
                f"Train Acc: {train_acc:.6f} "
                f"Val Top1: {val_acc:.6f} "
                f"Val Top5: {val_acc_top5:.6f} "
                f"F1: {val_f1:.6f} ",
                file=log,
                flush=True,
            )

In [None]:
# Hyper-parameter grid search: every (learning rate, weight decay) pair is
# trained from scratch, first for AlexNet and then for ResNet-18. Each
# architecture appends to its own log file.
LEARNING_RATES = [1e-5, 1e-4, 1e-3, 1e-2]
# Fixed: a "." instead of "," between 5e-4 and 5e-3 made this cell a SyntaxError.
WEIGHT_DECAYS = [5e-5, 5e-4, 5e-3, 5e-2]
GRID_RUNS = [
    ('alexnet', 'AlexNet_without_pretrain_upd.txt'),
    ('resnet18', 'ResNet_without_pretrain_upd.txt'),
]

for grid_model, grid_log in GRID_RUNS:
    for lr in LEARNING_RATES:
        for wd in WEIGHT_DECAYS:
            testing(model_name=grid_model,
                    l_rate=lr,
                    w_decay=wd,
                    log_file=grid_log
                    )