In [None]:
import torch
from ResNet import ResNet
from trainer import Trainer
from torchvision.datasets import ImageFolder
from util import count_parameters, show, test
from torch.utils.data import random_split, DataLoader
from torchvision.transforms import Compose, Resize, Normalize, ToTensor

In [None]:
data_path = './datasets/animals10'

# Fixed seed for the train/val/test partition. Later cells resume training from
# './checkpoint.pth'; without a fixed seed a kernel restart would draw a new
# partition, so images the checkpointed weights were trained on could end up in
# the validation/test splits.
SPLIT_SEED = 42

# Preprocessing applied to every image: resize to 64x64, convert to a tensor,
# then normalise the three channels.
# NOTE(review): the mean/std values are hard-coded and their origin is not
# visible in this notebook — confirm they were computed on this dataset.
T = Compose([
    Resize((64, 64)),
    ToTensor(),
    Normalize(mean=[.19, .19, .20], std=[.51, .50, .41]),
])

# ImageFolder derives the class labels from the sub-directory names of data_path.
dataset = ImageFolder(data_path, transform=T)

# 60% / 20% / 20% split, reproducible thanks to the seeded generator.
train_set, val_set, test_set = random_split(
    dataset,
    [.6, .2, .2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(f'train: {len(train_set)}\nvalid: {len(val_set)}')

# Only the training loader shuffles and drops the last incomplete batch;
# the evaluation loaders iterate over every sample in a fixed order.
loaders = {
    'train': DataLoader(train_set, shuffle=True,  batch_size=64, drop_last=True),
    'val':   DataLoader(val_set,   shuffle=False, batch_size=64),
    'test':  DataLoader(test_set,  shuffle=False, batch_size=64),
}

In [None]:
# Visual sanity check of the training data with class names as labels.
# NOTE(review): the two 4s are presumably the grid rows/columns — confirm
# against util.show.
show(
    loaders['train'],
    4,
    4,
    dataset.classes,
)

In [None]:
# Baseline network: 10 output classes, three stages of three blocks each with
# 32/64/128 hidden channels, and a dropout probability of 0.17.
model = ResNet(
    num_classes=10,
    num_blocks=[3, 3, 3],
    c_hidden=[32, 64, 128],
    dropout_prob=0.17,
)
# Left as the last expression so the cell displays whatever the helper returns.
count_parameters(model)

In [None]:
# First training phase on the freshly initialised model.
# NOTE(review): the 'test' loader is handed to the trainer here while the final
# cells evaluate on loaders['val'] — confirm that this assignment of splits
# (in-training evaluation vs. final hold-out) is intentional.
trainer = Trainer(
    model,
    loaders['train'],
    loaders['test'],
)
# Presumably 25 epochs — confirm against Trainer.start.
trainer.start(25)

In [None]:
# Second phase: rebuild the same architecture with a higher dropout probability
# (0.33 instead of 0.17), initialise it from the saved checkpoint and hand it to
# the existing trainer together with an AdamW optimizer.

# Prefer Apple's MPS backend, but fall back to the CPU instead of raising on
# machines where MPS is not available.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

model = ResNet(num_classes=10, num_blocks=[3,3,3], c_hidden=[32,64,128], dropout_prob=0.33)
model.to(device)
# map_location='cpu' makes deserialisation independent of the device the
# checkpoint was saved from; load_state_dict then copies the values into the
# parameters that already live on `device`.
# NOTE(review): './checkpoint.pth' is presumably written by trainer.start —
# confirm in trainer.py.
model.load_state_dict(torch.load('./checkpoint.pth', map_location='cpu'))

trainer.model = model
trainer.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=2e-1)

In [None]:
# Run the trainer again with the model/optimizer currently attached to it.
# NOTE(review): 7 is presumably the number of epochs and clear=False presumably
# keeps the previously recorded history/output — confirm against Trainer.start.
trainer.start(7, clear=False)

In [None]:
# Swap the trainer's optimizer for plain SGD (lr 1e-3, no momentum) on the
# current model's parameters, then run the trainer again.
trainer.optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=1e-3,
)
trainer.start(7, clear=False)

In [None]:
# Swap the trainer's optimizer for Adam with a small learning rate (1e-5) on
# the current model's parameters, then run the trainer again.
trainer.optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-5,
)
trainer.start(4, clear=False)

In [None]:
# Further phase: rebuild the same architecture with a lower dropout probability
# (0.1), initialise it from the saved checkpoint and continue training with
# Adam at its default hyper-parameters.

# Prefer Apple's MPS backend, but fall back to the CPU instead of raising on
# machines where MPS is not available.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

model = ResNet(num_classes=10, num_blocks=[3,3,3], c_hidden=[32,64,128], dropout_prob=0.1)
model.to(device)
# map_location='cpu' makes deserialisation independent of the device the
# checkpoint was saved from; load_state_dict then copies the values into the
# parameters that already live on `device`.
model.load_state_dict(torch.load('./checkpoint.pth', map_location='cpu'))

trainer.model = model
trainer.optimizer = torch.optim.Adam(model.parameters())
trainer.start(7, clear=False)

In [None]:
# Plot what the trainer has recorded so far.
# NOTE(review): presumably loss/metric curves accumulated across all the
# trainer.start(...) calls above — confirm against Trainer.plot_history.
trainer.plot_history()

In [None]:
# Restore the checkpointed weights before the final evaluation.
# The file is deserialised once (onto the CPU, so it loads regardless of the
# device it was saved from) and the resulting state dict is reused.
state_dict = torch.load('./checkpoint.pth', map_location='cpu')
load_result = model.load_state_dict(state_dict)

# In a top-to-bottom run `trainer.model` is the very same object as `model`,
# so a second load would be redundant; only load again if they have diverged
# (e.g. after out-of-order cell execution).
if trainer.model is not model:
    load_result = trainer.model.load_state_dict(state_dict)

# Last expression: display the key-matching report of the most recent load.
load_result

In [None]:
# Evaluate the trainer's current model on the 'val' split, which was not passed
# to the Trainer constructor above (that received the 'train' and 'test' loaders).
trainer.validate(loaders['val'])

In [None]:
# Run the util.test helper on the 'val' split, passing the class names.
# NOTE(review): presumably produces a per-class report or sample predictions —
# confirm against util.test.
test(model, loaders['val'], dataset.classes)