In [None]:
import numpy as np
import torch
from torch import nn, optim
from torchvision import transforms, datasets, models
from PIL import Image
import json
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import trange, tqdm

In [None]:
# Root folder of the flower dataset on disk.
data_dir = '/Users/js/programs/datasets/flower_data'
# Training split.
train_dir = f'{data_dir}/train'
# The 'valid' folder is the split this notebook evaluates on.
test_dir = f'{data_dir}/valid'

In [None]:
# Per-channel mean / std used to normalise every crop.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Build the per-crop transforms once instead of on every crop of every image.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)

# Resize -> TenCrop(224) -> tensor per crop -> normalise per crop.
# Each sample comes out as one stacked tensor of crops, so a batch is
# unpacked downstream as (batch, ncrops, c, h, w).
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.TenCrop(224),
    transforms.Lambda(lambda crops: torch.stack([_to_tensor(crop) for crop in crops])),
    transforms.Lambda(lambda crops: torch.stack([_normalize(crop) for crop in crops]))
])

# The same transform pipeline is applied to both splits.
train_dat = datasets.ImageFolder(train_dir, transform=data_transforms)
test_dat = datasets.ImageFolder(test_dir, transform=data_transforms)

train_loader = torch.utils.data.DataLoader(train_dat, batch_size=8, shuffle=True, num_workers=4)
# Fixed: this previously referenced the undefined name `val_dat`.
# Evaluation does not depend on sample order, so shuffling is switched off.
test_loader = torch.utils.data.DataLoader(test_dat, batch_size=8, shuffle=False, num_workers=4)

In [None]:
# Load torchvision's ResNet-50 with pretrained weights. Its `fc` attribute is
# replaced by a custom classification head in a later cell.
model = models.resnet50(pretrained=True)
# Bare expression so the notebook displays the network's current `fc` layer.
model.fc

In [None]:
class testnet(nn.Module):
    """Fully connected classification head: 2048 -> 8192 -> 4096 -> 1024 -> 102.

    Every hidden layer is followed by LeakyReLU and dropout (p=0.5); the
    final layer returns its raw output with no activation applied.
    """

    def __init__(self):
        super().__init__()
        # Hidden stack.
        self.fc1 = nn.Linear(2048, 8192)
        self.fc2 = nn.Linear(8192, 4096)
        self.fc3 = nn.Linear(4096, 1024)
        # Output layer.
        self.fc4 = nn.Linear(1024, 102)

        # Shared by all hidden layers.
        self.drop = nn.Dropout(p=0.5)
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        """Apply the three hidden layers in order, then the output layer."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            activated = self.relu(hidden(x))
            x = self.drop(activated)
        return self.fc4(x)

In [None]:
# Build the custom head and install it as the network's final layer.
clf = testnet()
model.fc = clf

# Use the GPU when CUDA is available; otherwise leave the model where it is.
model = model.cuda() if torch.cuda.is_available() else model

In [None]:
# Freeze every parameter of the network...
for param in model.parameters():
    param.requires_grad = False
# ...then re-enable gradients for the new head only. Assigning
# `model.fc.requires_grad = True` merely sets an attribute on the module and
# leaves all of its parameters frozen, so the flag has to be set per parameter.
for param in model.fc.parameters():
    param.requires_grad = True

# Only the head is trainable, so only its parameters are handed to Adam.
optimizer = optim.Adam(model.fc.parameters(), lr=1e-5)
# One cosine cycle spans one pass over the training loader; the training loop
# steps the scheduler once per batch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader))
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Per-epoch history filled in by the training loop: training loss, test loss
# and test accuracy. Three separate, initially empty lists.
train_losses_1, test_losses_1, accs = [], [], []

In [None]:
epochs = 10

# Fixed: `device` was used when reloading the checkpoint but never defined.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _forward_ten_crop(net, batch):
    """Run `net` on a (batch, ncrops, c, h, w) tensor and average over the crops.

    The crops are folded into the batch dimension for the forward pass, then
    the outputs are reshaped back and averaged so each image gets one row.
    """
    bs, ncrops, c, h, w = batch.shape
    logits = net(batch.view(-1, c, h, w))
    return logits.view(bs, ncrops, -1).mean(1)


for e in trange(epochs, desc='Total training'):
    # ---- training pass ----
    model.train()
    train_loss = 0
    for feature, target in tqdm(train_loader, desc=f'Training for epoch {e+1}'):
        feature, target = feature.to(device), target.to(device)
        optimizer.zero_grad()
        output = _forward_ten_crop(model, feature)
        loss = loss_fn(output, target)

        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped per batch: one cosine cycle per epoch
        # Accumulates the sum of the per-batch loss values (not an average).
        train_loss += loss.item()

    # Start a fresh cosine cycle for the next epoch. By now the learning rate
    # has been annealed to its minimum, and a newly built CosineAnnealingLR
    # starts from the optimizer's *current* lr, so restore the initial lr
    # first; otherwise later epochs would train with a near-zero step size.
    for group in optimizer.param_groups:
        group['lr'] = group.get('initial_lr', group['lr'])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader))

    # ---- evaluation pass ----
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for feature, target in tqdm(test_loader, desc=f'Testing for epoch {e+1}'):
            feature, target = feature.to(device), target.to(device)
            output = _forward_ten_crop(model, feature)
            loss = loss_fn(output, target)
            test_loss += loss.item()
            preds = output.max(dim=1)[1]
            # Fixed: this compared against the undefined name `label`.
            correct += (preds == target).sum().item()

    accuracy = correct/len(test_dat)

    train_losses_1.append(train_loss)
    test_losses_1.append(test_loss)
    accs.append(accuracy)

    print(f'Epoch {e+1} | Training loss: {train_loss} | Testing loss: {test_loss} | Accuracy: {accuracy*100:.3f}%')

    # Checkpoint on the first recorded epoch, and afterwards whenever this
    # epoch's test loss is the lowest seen so far.
    # NOTE(review): this pickles the whole module rather than a state_dict;
    # recent PyTorch releases default torch.load to weights_only=True, which
    # would reject such a file — confirm against the installed version.
    if len(test_losses_1) == 1 or test_loss <= min(test_losses_1):
        torch.save(model, 'flower.pt')


# at end of training, load the best performing model
model = torch.load('flower.pt', map_location=device)
model.to(device)

In [1]:
# Left panel: per-epoch training vs. test loss. Right panel: test accuracy.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

# Both curves are sums of per-batch losses, so their absolute scales depend on
# the number of batches in each loader.
axs[0].plot(train_losses_1, label='train')
axs[0].plot(test_losses_1, label='test')
axs[0].set(title='Summed batch loss per epoch', xlabel='Epoch', ylabel='Loss')
axs[0].legend()

axs[1].plot(accs)
axs[1].set(title='Test accuracy per epoch', xlabel='Epoch', ylabel='Accuracy')

fig.tight_layout();

NameError: name 'plt' is not defined