In [None]:
import n2d2
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets
from neurocorgi_sdk.models import NeuroCorgiNet_torch

In [None]:
# Use GPU 0 when available, otherwise fall back to the CPU.
# PyTorch and N2D2 each check CUDA availability on their own.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fix the PyTorch RNG so the DataLoader shuffling and the head's random
# initialisation are reproducible across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

n2d2.global_variables.default_model = "Frame_CUDA" if n2d2.global_variables.cuda_available else "Frame"
n2d2.global_variables.cuda_device = 0

# Lowest N2D2 verbosity level (presumably minimal logging — confirm in n2d2 docs).
n2d2.global_variables.verbosity = 0

In [None]:

# Both splits get the same preprocessing: resize to the 224x224 input size the
# extractor is built for, then convert to a float tensor.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset_train = torchvision.datasets.CIFAR100(
    root="./data", train=True, download=True, transform=train_transforms,
)
dataset_test = torchvision.datasets.CIFAR100(
    root="./data", train=False, download=True, transform=test_transforms,
)

# drop_last=True keeps every batch at exactly 128 samples; presumably required
# because the extractor is constructed for a fixed batch dimension (TODO confirm).
train_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=128, shuffle=True, drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset_test, batch_size=128, shuffle=False, drop_last=True,
)

In [None]:
def im_convert(tensor):
  """Convert a channel-first image tensor into an array matplotlib can show.

  Args:
    tensor: image tensor laid out as (C, H, W), on any device.

  Returns:
    A NumPy array of shape (H, W, C) with values clipped to [0, 1]. The
    input tensor is detached and copied first, so it is left untouched.
  """
  array = tensor.detach().cpu().clone().numpy()
  # Channel-first (C, H, W) -> channel-last (H, W, C) for imshow.
  return np.clip(np.transpose(array, (1, 2, 0)), 0, 1)

In [None]:
# Human-readable label names; indexed by the integer labels the loaders yield.
classes = dataset_train.classes
print(classes)

In [None]:
# Preview the first 20 images of one shuffled training batch, titled with their label names.
dataiter = iter(train_loader)
images, labels = next(dataiter)
fig = plt.figure(figsize=(25, 4))

for position in range(20):
  ax = fig.add_subplot(2, 10, position + 1, xticks=[], yticks=[])
  ax.imshow(im_convert(images[position]))
  label_id = labels[position].item()
  ax.set_title(classes[label_id])

In [None]:
# For this example, we use the model pretrained and quantized with the ImageNet dataset.
# The extractor is constructed for a fixed input shape [batch, channels, height, width].
# The batch dimension is taken from the loaders instead of repeating the literal 128,
# so changing the loaders' batch size cannot silently desynchronise the extractor.
# Both loaders feed the same extractor, so they must agree on the batch size.
assert train_loader.batch_size == test_loader.batch_size, "train/test batch sizes must match the extractor input"
extractor = NeuroCorgiNet_torch([train_loader.batch_size, 3, 224, 224], weights_dir="data/imagenet_weights")
extractor.to(device)

In [None]:
# Classification head applied to the extractor's conv9_1x1 output:
# average-pool (kernel 7, stride 7), flatten to (batch, features), then a linear layer
# producing 100 logits (one per CIFAR-100 class).
# NOTE(review): assumes conv9_1x1 has 1024 channels and a 7x7 spatial size — confirm.
pool = torch.nn.AvgPool2d(kernel_size=7, stride=7)
flatten = torch.nn.Flatten(start_dim=1)
classifier = torch.nn.Linear(in_features=1024, out_features=100)
head = torch.nn.Sequential(pool, flatten, classifier).to(device)
head

In [None]:
# Cross-entropy on the head's logits. Only the head's parameters are handed to SGD;
# the extractor is not registered with this optimizer.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=head.parameters(), lr=0.01)

In [None]:
# Extractor in inference mode, head in training mode, before the training loop.
extractor, head = extractor.eval(), head.train()

In [None]:
epochs = 15
running_loss_history = []
running_corrects_history = []
val_running_loss_history = []
val_running_corrects_history = []

# Train only the head on the extractor's conv9_1x1 output; evaluate on the
# test split after every epoch.
#
# Averaging:
# - CrossEntropyLoss uses reduction='mean' by default, so each batch loss is
#   already a per-sample mean. It is re-weighted by the batch size before
#   summing, then divided by the number of samples actually processed.
# - Both loaders use drop_last=True, so that count can be smaller than
#   len(loader.dataset). It is therefore tracked explicitly.
for e in range(epochs):

  running_loss = 0.0
  running_corrects = 0.0
  n_train_samples = 0
  val_running_loss = 0.0
  val_running_corrects = 0.0
  n_val_samples = 0

  # ---- training pass ----
  head.train()
  for inputs, labels in tqdm(train_loader):
    inputs = inputs.to(device)
    labels = labels.to(device)

    conv3_1x1, conv5_1x1, conv7_5_1x1, conv9_1x1 = extractor(inputs)
    outputs = head(conv9_1x1)

    loss = criterion(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    _, preds = torch.max(outputs, 1)
    n_batch = labels.size(0)
    running_loss += loss.item() * n_batch
    running_corrects += torch.sum(preds == labels)
    n_train_samples += n_batch

  # ---- validation pass: no autograd graph, head in inference mode ----
  head.eval()
  with torch.no_grad():
    for val_inputs, val_labels in tqdm(test_loader):
      val_inputs = val_inputs.to(device)
      val_labels = val_labels.to(device)

      conv3_1x1, conv5_1x1, conv7_5_1x1, conv9_1x1 = extractor(val_inputs)
      val_outputs = head(conv9_1x1)

      val_loss = criterion(val_outputs, val_labels)

      _, val_preds = torch.max(val_outputs, 1)
      n_batch = val_labels.size(0)
      val_running_loss += val_loss.item() * n_batch
      val_running_corrects += torch.sum(val_preds == val_labels)
      n_val_samples += n_batch

  # Losses are Python floats. Accuracies are 0-d tensors (on `device`).
  epoch_loss = running_loss / n_train_samples
  epoch_acc = running_corrects.float() / n_train_samples
  running_loss_history.append(epoch_loss)
  running_corrects_history.append(epoch_acc)

  val_epoch_loss = val_running_loss / n_val_samples
  val_epoch_acc = val_running_corrects.float() / n_val_samples
  val_running_loss_history.append(val_epoch_loss)
  val_running_corrects_history.append(val_epoch_acc)
  print('epoch :', (e+1))
  print('training loss: {:.4f}, acc {:.4f} '.format(epoch_loss, epoch_acc.item()))
  print('validation loss: {:.4f}, validation acc {:.4f} '.format(val_epoch_loss, val_epoch_acc.item()))

In [None]:
# Training and validation loss recorded once per epoch by the training loop.
plt.plot(running_loss_history, label='training loss')
plt.plot(val_running_loss_history, label='validation loss')
plt.xlabel('epoch index')
plt.ylabel('loss')
plt.title('Loss per epoch')
# Trailing ';' suppresses the Legend repr in the cell output.
plt.legend();

In [None]:
# The accuracy histories may hold 0-d (possibly CUDA) tensors. Convert them to plain
# floats for matplotlib into new lists, leaving the recorded histories untouched.
# float() accepts both 0-d tensors and Python numbers.
train_acc_curve = [float(x) for x in running_corrects_history]
val_acc_curve = [float(x) for x in val_running_corrects_history]

plt.plot(train_acc_curve, label='training accuracy')
plt.plot(val_acc_curve, label='validation accuracy')
plt.xlabel('epoch index')
plt.ylabel('accuracy')
plt.title('Accuracy per epoch')
# Trailing ';' suppresses the Legend repr in the cell output.
plt.legend();

In [None]:
# Put both networks in inference mode before the final evaluation.
extractor, head = extractor.eval(), head.eval()

In [None]:
# Final accuracy of the trained head on the test split.
# The correct-prediction count is divided by the number of test samples actually
# evaluated. It is not divided by a batch count or by len(dataset), because
# test_loader uses drop_last=True and may skip a final partial batch.
running_corrects = 0.0
n_test_samples = 0

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        conv3_1x1, conv5_1x1, conv7_5_1x1, conv9_1x1 = extractor(inputs)
        outputs = head(conv9_1x1)

        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels)
        n_test_samples += labels.size(0)

epoch_acc = running_corrects.float() / n_test_samples
print('test acc {:.4f} ({} samples)'.format(epoch_acc.item(), n_test_samples))

In [None]:
# Show predictions for the first images of one test batch.
# Each title is "predicted (true)": green when correct, red otherwise.
dataiter = iter(test_loader)
images, labels = next(dataiter)
images = images.to(device)
labels = labels.to(device)

# Inference only: no autograd graph is needed for the predictions.
with torch.no_grad():
  conv3_1x1, conv5_1x1, conv7_5_1x1, conv9_1x1 = extractor(images)
  output = head(conv9_1x1)
  _, preds = torch.max(output, 1)

fig = plt.figure(figsize=(25, 4))

# Never index past the batch, even if the batch size is reduced below 20.
n_show = min(20, images.size(0))
for idx in range(n_show):
  ax = fig.add_subplot(2, 10, idx + 1, xticks=[], yticks=[])
  ax.imshow(im_convert(images[idx]))
  pred_id = preds[idx].item()
  true_id = labels[idx].item()
  ax.set_title("{} ({})".format(str(classes[pred_id]), str(classes[true_id])),
               color=("green" if pred_id == true_id else "red"))