In [1]:
import torch
from torchvision.datasets import Food101
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torchvision.transforms import v2 as tv2
from torch.utils.data import random_split
from torchvision.models import resnet18
from torch import nn
from torch.optim import Adam
from tqdm import tqdm
from pathlib import Path
from torchvision.models import ResNet

In [2]:
# Mini-batch size used by the train / validation / test DataLoaders.
BATCH_SIZE: int = 128

In [3]:
# Standard ImageNet channel statistics, matching the normalization used by
# torchvision's pretrained weights that the backbone is loaded with.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _to_normalized_tensor():
    """Return fresh steps that turn an image into a float tensor in [0, 1] and normalize it."""
    return [
        tv2.ToImage(),
        tv2.ToDtype(torch.float32, scale=True),
        tv2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]


# Training: random 224 crop of the 256-resized image (light augmentation).
train_transform = tv2.Compose([tv2.Resize(256), tv2.RandomCrop(224), *_to_normalized_tensor()])

# Evaluation: deterministic center crop.
test_transform = tv2.Compose([tv2.Resize(256), tv2.CenterCrop(224), *_to_normalized_tensor()])

# Visualization only: same geometry as evaluation, left as un-normalized image tensors.
demonstr_transform = tv2.Compose([tv2.Resize(256), tv2.CenterCrop(224), tv2.ToImage()])

train_ds = Food101(root="data", split="train", download=True, transform=train_transform)

# Validation and test sets are disjoint 20% / 80% slices of the official "test" split;
# the seeded generator keeps the partition identical across runs.
food101_test_ds = Food101(root="data", split="test", download=True, transform=test_transform)
val_ds, test_ds = random_split(food101_test_ds, [0.2, 0.8], generator=torch.Generator().manual_seed(42))
demonstration_ds = Food101(root="data", split="test", download=True, transform=demonstr_transform)

# All loaders share batch size, worker count and pinned memory; only training shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_dl = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_dl = DataLoader(val_ds, **_loader_kwargs)
test_dl = DataLoader(test_ds, **_loader_kwargs)

In [4]:
# Class names, indexed by the integer label each dataset sample carries.
classes_li = demonstration_ds.classes


def show_samples(dataset, class_names, rows=4, cols=4, seed=None):
    """Plot a ``rows`` x ``cols`` grid of randomly chosen samples from ``dataset``.

    Replaces an earlier commented-out sketch that could not run as written: it indexed
    the dataset with a tensor, handed a channels-first tensor straight to ``imshow`` and
    formatted the bound method ``img.size`` instead of the image size.

    Args:
        dataset: indexable dataset returning ``(image, label)``. Images are assumed to be
            channels-first tensors such as those produced by ``demonstr_transform``
            (which ends in ``tv2.ToImage()``) -- confirm before using other datasets.
        class_names: sequence mapping integer labels to display names.
        rows, cols: grid shape.
        seed: optional seed for a reproducible choice of samples; ``None`` draws a
            fresh random selection on every call.
    """
    generator = torch.Generator()
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(seed)
    indices = torch.randint(len(dataset), (rows * cols,), generator=generator).tolist()

    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
    for ax, idx in zip(axes.flat, indices):
        img, label = dataset[idx]
        ax.imshow(img.permute(1, 2, 0))  # CHW -> HWC, the layout matplotlib expects
        ax.set_title(f"{class_names[label]}\n{tuple(img.shape[1:])}")
        ax.axis("off")
    fig.tight_layout()
    plt.show()


# Preview with: show_samples(demonstration_ds, classes_li)

In [5]:
def train_model(
    num_epochs: int,
    model: nn.Module,
    train_data_loader,
    validation_data_loader,
    optimizer,
    loss_function,
    device,
    save_directory: str = None,
    model_name: str = None,
    start_epoch: int = 0,
):
    """Train ``model`` for ``num_epochs`` epochs, validating after every epoch.

    Epochs are numbered ``start_epoch + 1`` .. ``start_epoch + num_epochs``, so a resumed
    run continues the numbering of the checkpoint it was loaded from.

    Args:
        num_epochs: number of epochs to run in this call.
        model: module to train. It is not moved; the caller places it on ``device``.
        train_data_loader: iterable of ``(inputs, targets)`` batches used for optimisation.
        validation_data_loader: iterable of ``(inputs, targets)`` batches used to measure
            top-1 accuracy (``argmax`` over dim 1 of the model output).
        optimizer: optimizer that steps the trainable parameters of ``model``.
        loss_function: callable ``(outputs, targets) -> scalar loss tensor``.
        device: device every batch is moved to.
        save_directory: if given, the whole model is ``torch.save``-d there after each
            epoch's training phase as ``<model_name or "model">_epoch_<epoch>.pth``.
            Missing parent directories are created.
        model_name: checkpoint filename prefix. Defaults to ``"model"``.
        start_epoch: number of epochs already completed (used only for numbering).

    Each epoch prints the sample-weighted mean training loss and the validation accuracy.

    The per-module train/eval flags the model has on entry (e.g. BatchNorm deliberately
    kept in eval mode) are restored before every training phase, because validation
    switches the whole model to eval mode. After at least one epoch, the model is left in
    eval mode.
    """
    # Snapshot the caller's per-module modes so validation's model.eval() does not leak
    # into the next epoch's training phase.
    train_modes = [(module, module.training) for module in model.modules()]
    last_epoch = start_epoch + num_epochs

    for epoch in range(start_epoch + 1, last_epoch + 1):
        for module, was_training in train_modes:
            module.training = was_training

        train_bar = tqdm(train_data_loader, desc=f"Training Epoch: {epoch}/{last_epoch}", leave=False, unit="batch")
        running_loss = 0.0
        seen_train = 0
        for imgs, labels in train_bar:
            imgs, labels = imgs.to(device), labels.to(device)
            output = model(imgs)
            loss = loss_function(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_item = loss.item()
            # Weight by the actual batch length: the last batch may be smaller than batch_size.
            batch_len = imgs.shape[0]
            running_loss += loss_item * batch_len
            seen_train += batch_len
            train_bar.set_postfix(loss=f"{loss_item:.4f}")

        if save_directory is not None:
            save_path = Path(save_directory)
            save_path.mkdir(parents=True, exist_ok=True)
            prefix = model_name if model_name is not None else "model"
            # `epoch` already includes start_epoch; adding it again would mislabel resumed runs.
            torch.save(model, save_path / f"{prefix}_epoch_{epoch}.pth")

        model.eval()
        val_bar = tqdm(validation_data_loader, desc="Validation", leave=False, unit="batch")
        correct = 0
        seen_val = 0
        with torch.no_grad():
            for imgs, labels in val_bar:
                imgs, labels = imgs.to(device), labels.to(device)
                output = model(imgs)
                pred = output.argmax(dim=1)
                batch_correct = (pred == labels).sum().item()
                correct += batch_correct
                seen_val += imgs.shape[0]
                val_bar.set_postfix(accuracy=f"{batch_correct / imgs.shape[0]:.4f}")

        # Guard against empty loaders instead of dividing by zero.
        epoch_loss = running_loss / seen_train if seen_train else float("nan")
        epoch_acc = correct / seen_val if seen_val else float("nan")
        print("\r" + f"Epoch: {epoch}/{last_epoch}, loss = {epoch_loss:.4f}, accuracy = {epoch_acc:.4f}")

In [6]:
# Use the available accelerator (e.g. CUDA) if there is one, otherwise fall back to CPU.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"using: {device}")

# Pretrained ResNet-18. `weights` is keyword-only in torchvision >= 0.13; passing it
# positionally only works through a deprecation shim.
model = resnet18(weights="DEFAULT")

# Swap the pretrained classification head for one sized to this dataset's label set.
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, len(train_ds.classes))
model.to(device)

loss_fun = nn.CrossEntropyLoss()

using: cuda




In [7]:
# Stage 1: train only the new classifier head on top of the frozen backbone.
HEAD_EPOCHS = 2
HEAD_LR = 4e-3
CHECKPOINT_DIR = "/home/oedada/Projects/experiments/NTO/Final/neuronki/data/checkpoints/for_best"

# Freeze every parameter, then re-enable gradients for the head alone.
model.requires_grad_(False)
model.fc.requires_grad_(True)

# Train mode overall, but BatchNorm layers stay in eval so their running statistics are not updated.
model.train()
for bn in (layer for layer in model.modules() if isinstance(layer, nn.BatchNorm2d)):
    bn.eval()

optimizer = Adam(model.fc.parameters(), lr=HEAD_LR)
train_model(HEAD_EPOCHS, model, train_dl, val_dl, optimizer, loss_fun, device, CHECKPOINT_DIR, "food101")

                                                                                      

Epoch: 1/2, loss = 2.2611, accurancy = 0.5638


                                                                                      

Epoch: 2/2, loss = 1.8868, accurancy = 0.5650




In [8]:
# Stage 2: unfreeze layer3 and layer4 and fine-tune them together with the head,
# resuming from the stage-1 checkpoint.
CHECKPOINT_DIR = "/home/oedada/Projects/experiments/NTO/Final/neuronki/data/checkpoints/for_best"
RESUME_EPOCH = 2
FINETUNE_EPOCHS = 13

# weights_only=False unpickles a whole nn.Module, which can execute arbitrary code, so only
# load checkpoints you produced yourself. map_location puts the tensors on the current
# `device`, so a checkpoint saved on an accelerator also loads on a CPU-only machine.
model = torch.load(
    f"{CHECKPOINT_DIR}/food101_epoch_{RESUME_EPOCH}.pth",
    weights_only=False,
    map_location=device,
)

# Train mode overall, but BatchNorm layers stay in eval so their running statistics are not updated.
model.train()
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
for param in model.layer3.parameters():
    param.requires_grad = True
for param in model.layer4.parameters():
    param.requires_grad = True

# Per-group learning rates: smallest for layer3, larger for layer4, largest for the head.
optimizer = Adam([
    {"params": model.layer3.parameters(), "lr": 1e-5},
    {"params": model.layer4.parameters(), "lr": 2e-5},
    {"params": model.fc.parameters(), "lr": 3e-4}
])
# NOTE(review): in the recorded run, validation accuracy peaked around epoch 11 and then
# declined, so the final checkpoint is not the best one. Consider selecting by val accuracy.
train_model(FINETUNE_EPOCHS, model, train_dl, val_dl, optimizer, loss_fun, device, CHECKPOINT_DIR, "food101", start_epoch=RESUME_EPOCH)

                                                                                       

Epoch: 3/15, loss = 1.4649, accurancy = 0.6634


                                                                                       

Epoch: 4/15, loss = 1.2658, accurancy = 0.6818


                                                                                       

Epoch: 5/15, loss = 1.1358, accurancy = 0.6949


                                                                                       

Epoch: 6/15, loss = 1.0301, accurancy = 0.7000


                                                                                       

Epoch: 7/15, loss = 0.9355, accurancy = 0.7125


                                                                                       

Epoch: 8/15, loss = 0.8523, accurancy = 0.7170


                                                                                       

Epoch: 9/15, loss = 0.7794, accurancy = 0.7263


                                                                                        

Epoch: 10/15, loss = 0.7090, accurancy = 0.7271


                                                                                        

Epoch: 11/15, loss = 0.6460, accurancy = 0.7313


                                                                                        

Epoch: 12/15, loss = 0.5838, accurancy = 0.7253


                                                                                        

Epoch: 13/15, loss = 0.5274, accurancy = 0.7293


                                                                                        

Epoch: 14/15, loss = 0.4788, accurancy = 0.7271


                                                                                        

Epoch: 15/15, loss = 0.4323, accurancy = 0.7188


