In [1]:
import sys
sys.path.append('..')
import torch
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
from src.training import train_model
from src.models import Resnet18_FC_Changed, Shufflenet_v2_x0_5, EfficientNetB0_FC_Changed
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import StepLR
import os
# Set the `benchmark` flag on torch.backends.cudnn.
# NOTE(review): the recorded output of this cell reports "Device: cpu", so this
# flag presumably had no effect in that run — confirm on a CUDA machine.
cudnn.benchmark = True
# Switch matplotlib to interactive mode. As the last expression of the cell,
# its return value is what the notebook echoes in the cell output.
plt.ion()

Device: cpu


<contextlib.ExitStack at 0x1abe708f250>

# Hyperparameters

In [2]:
# --- Optimiser settings (consumed by torch.optim.SGD in the training cells) ---
LEARNING_RATE = 1e-4  # 0.0001
WEIGHT_DECAY = 1e-8   # 0.00000001
# NOTE(review): 0.9 is the customary SGD momentum; confirm that 0.09 is intended.
MOMENTUM = 0.09

# --- Training-loop settings ---
NUM_EPOCHS = 3
BATCH_SIZE = 32
# Number of iterations for the first restart. Not referenced by the training
# cells shown in this notebook, which all build a StepLR scheduler.
T_0 = 1000

# --- Model head ---
# Passed to every model constructor in the training cells.
num_of_classes = 64

In [3]:
from src.data_loader import GetDataLoaders, GetDataLoadersLimited

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the dataset path from its components so the separator matches the host
# OS. On Windows this yields the same '..\\data\\miniImageNet' string as before;
# elsewhere the hard-coded backslashes would not resolve to a directory.
data_dir = os.path.join('..', 'data', 'miniImageNet')

# NOTE(review): the meaning of the third argument (20) is defined inside
# src.data_loader and is not visible here — TODO confirm and give it a name.
dataloaders, class_names, dataset_sizes = GetDataLoadersLimited(data_dir, BATCH_SIZE, 20)

# Convenience handles for the three splits returned in `dataloaders`.
train_loader = dataloaders['train']
validation_loader = dataloaders['val']
test_loader = dataloaders['test']

In [4]:
# Report, for each split: number of images, number of classes, and the
# average number of images per class (true division, so a float is printed).
for split_header, split_loader in (
    (" -- Training Set -- ", train_loader),
    (" -- Validation Set -- ", validation_loader),
    (" -- Test Set -- ", test_loader),
):
    split_dataset = split_loader.dataset
    image_count = len(split_dataset.imgs)
    class_count = len(split_dataset.classes)

    print(split_header)
    print(image_count)
    print(class_count)
    print(image_count / class_count)

 -- Training Set -- 
38400
64
600.0
 -- Validation Set -- 
9600
16
600.0
 -- Test Set -- 
12000
20
600.0


In [5]:
# Train the model built by Shufflenet_v2_x0_5 with every parameter trainable.
net = Shufflenet_v2_x0_5(num_of_classes).to(device)
for p in net.parameters():
    p.requires_grad = True
net.train()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# NOTE(review): step_size=7 is larger than NUM_EPOCHS, so if train_model steps
# the scheduler once per epoch the learning rate never decays — confirm.
scheduler = StepLR(optimizer, step_size=7, gamma=0.01)

# Resolve the checkpoint location before training: the path is built with
# os.path.join so it is valid on any OS, and the directory is created up front
# so a missing folder cannot make torch.save fail after training has finished.
save_path = os.path.join('..', 'data', 'models', 'best_model_Shufflenet_v2_x0_5.pth')
os.makedirs(os.path.dirname(save_path), exist_ok=True)

print('Please wait patiently, it may take some seconds...')
best_model = train_model(net, dataloaders, criterion, optimizer, scheduler, NUM_EPOCHS, dataset_sizes)
# Persist only the weights (state_dict) of the model returned by train_model.
torch.save(best_model.state_dict(), save_path)

Please wait patiently, it may take some seconds...
Epoch 0/2
----------


In [None]:
# Qualitative check of `best_model` as returned by the ShuffleNet training cell
# above; num_images=6 is forwarded to visualize_model unchanged.
from src.modelvis import visualize_model

visualize_model(
    best_model,
    dataloaders,
    num_images=6,
)

In [None]:
from src.models import Resnet18_FC_Changed, EfficientNetB0_FC_Changed

# Train the model built by Resnet18_FC_Changed with every parameter trainable.
net = Resnet18_FC_Changed(num_of_classes).to(device)
for p in net.parameters():
    p.requires_grad = True
net.train()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# NOTE(review): step_size=7 is larger than NUM_EPOCHS, so if train_model steps
# the scheduler once per epoch the learning rate never decays — confirm.
scheduler = StepLR(optimizer, step_size=7, gamma=0.01)

# Resolve the checkpoint location before training: the path is built with
# os.path.join so it is valid on any OS, and the directory is created up front
# so a missing folder cannot make torch.save fail after training has finished.
save_path = os.path.join('..', 'data', 'models', 'best_model_Resnet18.pth')
os.makedirs(os.path.dirname(save_path), exist_ok=True)

print('Please wait patiently, it may take some seconds...')
# dataset_sizes is passed explicitly, matching the ShuffleNet training cell,
# which is the only train_model call in this notebook with recorded output.
best_model = train_model(net, dataloaders, criterion, optimizer, scheduler, NUM_EPOCHS, dataset_sizes)
# Persist only the weights (state_dict) of the model returned by train_model.
torch.save(best_model.state_dict(), save_path)

In [None]:
# Qualitative check of `best_model` as returned by the ResNet18 training cell
# above; num_images=6 is forwarded to visualize_model unchanged.
from src.modelvis import visualize_model

visualize_model(
    best_model,
    dataloaders,
    num_images=6,
)

In [None]:
# Train the model built by EfficientNetB0_FC_Changed with every parameter trainable.
net = EfficientNetB0_FC_Changed(num_of_classes).to(device)
for p in net.parameters():
    p.requires_grad = True
net.train()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# NOTE(review): step_size=7 is larger than NUM_EPOCHS, so if train_model steps
# the scheduler once per epoch the learning rate never decays — confirm.
scheduler = StepLR(optimizer, step_size=7, gamma=0.01)

# Resolve the checkpoint location before training: the path is built with
# os.path.join so it is valid on any OS, and the directory is created up front
# so a missing folder cannot make torch.save fail after training has finished.
save_path = os.path.join('..', 'data', 'models', 'best_model_EfficientNetB0_FC_Changed.pth')
os.makedirs(os.path.dirname(save_path), exist_ok=True)

print('Please wait patiently, it may take some seconds...')
# dataset_sizes is passed explicitly, matching the ShuffleNet training cell,
# which is the only train_model call in this notebook with recorded output.
best_model = train_model(net, dataloaders, criterion, optimizer, scheduler, NUM_EPOCHS, dataset_sizes)
# Persist only the weights (state_dict) of the model returned by train_model.
torch.save(best_model.state_dict(), save_path)

In [None]:
# Qualitative check of `best_model` as returned by the EfficientNetB0 training
# cell above; num_images=6 is forwarded to visualize_model unchanged.
from src.modelvis import visualize_model

visualize_model(
    best_model,
    dataloaders,
    num_images=6,
)