In [None]:
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
from neurocorgi_sdk import NeuroCorgiNet
from neurocorgi_sdk.transforms import ToNeuroCorgiChip

In [None]:
# Seed PyTorch's global RNG so that the head's initial weights and the
# training-loader shuffle order repeat from one Restart & Run All to the next.
SEED = 0
torch.manual_seed(SEED)

# If possible, set up the GPU 0 for the application
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
BATCH_SIZE = 128


def make_transforms():
    """Build the preprocessing pipeline shared by the train and test splits.

    Resizes to 32x32, converts to a tensor, then applies ToNeuroCorgiChip so
    the images are in the input format the NeuroCorgi extractor expects.
    """
    return transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        ToNeuroCorgiChip(),
    ])


train_transforms = make_transforms()
test_transforms = make_transforms()

dataset_train = datasets.CIFAR100(root="./data", train=True, download=True, transform=train_transforms)
dataset_test = datasets.CIFAR100(root="./data", train=False, download=True, transform=test_transforms)

# Dropping the last incomplete batch is harmless for SGD on the training set,
# but on the test set it would silently leave samples out of the evaluation.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

In [None]:
def im_convert(tensor, scale=255):
    """Convert a channels-first image tensor into an array ``plt.imshow`` accepts.

    The tensor is detached and copied to host memory. It is then transposed
    from (C, H, W) to (H, W, C), divided by ``scale``, and clipped to [0, 1].

    Args:
        tensor: image tensor in (C, H, W) layout. The default ``scale``
            assumes pixel values in [0, 255] (presumably the range produced by
            ToNeuroCorgiChip -- TODO confirm against the SDK).
        scale: value that maps to full intensity. Defaults to 255.

    Returns:
        numpy.ndarray of shape (H, W, C) with values in [0, 1].
    """
    image = tensor.detach().cpu().numpy().transpose(1, 2, 0)
    # Normalize first, then clip. Clipping raw [0, 255] values to [0, 1]
    # before dividing squeezed every pixel into [0, 1/255] (a black image).
    return (image / scale).clip(0, 1)

In [None]:
# Class names of the training dataset; a label index selects its name.
classes = dataset_train.classes
print(classes)

In [None]:
# Preview the first 20 images of one shuffled training batch with their labels,
# laid out on a 2 x 10 grid without tick marks.
dataiter = iter(train_loader)
images, labels = next(dataiter)

fig = plt.figure(figsize=(25, 4))
n_rows, n_cols = 2, 10
for idx in range(n_rows * n_cols):
    ax = fig.add_subplot(n_rows, n_cols, idx + 1, xticks=[], yticks=[])
    ax.imshow(im_convert(images[idx]))
    class_name = classes[labels[idx].item()]
    ax.set_title(class_name)

In [None]:
# Feature extractor loaded from weights pretrained and quantized on ImageNet.
weights_path = "neurocorginet_imagenet.safetensors"
extractor = NeuroCorgiNet(weights_path)
extractor.to(device)

In [None]:
# Trainable classifier: flatten the extractor's div32 output, then map it to
# 100 logits (one per CIFAR-100 class).
# NOTE(review): assumes the flattened div32 map has 1024 features for 32x32
# inputs -- TODO confirm against the NeuroCorgiNet architecture.
flatten = torch.nn.Flatten(start_dim=1)
classifier = torch.nn.Linear(1024, 100)
head = torch.nn.Sequential(flatten, classifier)
head.to(device)
head = head.train()

In [None]:
# Only the head's parameters are handed to the optimizer, so the extractor's
# weights are never updated.
learning_rate = 0.01
optimizer = torch.optim.SGD(head.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
epochs = 15
running_loss_history = []
running_corrects_history = []
val_running_loss_history = []
val_running_corrects_history = []


def run_epoch(loader, train):
    """Run the frozen extractor plus the head over every batch of ``loader``.

    When ``train`` is true, the head is updated with ``optimizer`` after each
    batch. Otherwise the pass is evaluation-only and gradients are disabled.

    Returns:
        (mean_loss, accuracy): the mean per-sample loss as a float, and the
        accuracy as a 0-dim CPU tensor. Both are averaged over the samples
        actually processed, so batches dropped by ``drop_last`` do not skew
        them. Both are NaN if the loader yields no batches.
    """
    head.train(train)
    total_loss = 0.0
    total_correct = 0
    seen = 0

    for inputs, labels in tqdm(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Only head.parameters() are optimized, so the extractor's forward
        # pass does not need to be recorded for autograd.
        with torch.no_grad():
            _, _, _, div32 = extractor(inputs)

        with torch.set_grad_enabled(train):
            # NOTE(review): 15 presumably rescales the quantized div32
            # activations to roughly [0, 1] -- confirm against the SDK.
            outputs = head(div32 / 15.)
            loss = criterion(outputs, labels)

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_size = labels.size(0)
        # CrossEntropyLoss() averages over the batch, so re-weight by batch
        # size to get a true per-sample mean over the epoch.
        total_loss += loss.item() * batch_size
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch_size

    if seen == 0:
        nan = float("nan")
        return nan, torch.tensor(nan)
    return total_loss / seen, torch.tensor(total_correct / seen)


for e in range(epochs):
    print('epoch :', (e+1))

    epoch_loss, epoch_acc = run_epoch(train_loader, train=True)
    val_epoch_loss, val_epoch_acc = run_epoch(test_loader, train=False)

    running_loss_history.append(epoch_loss)
    running_corrects_history.append(epoch_acc)
    val_running_loss_history.append(val_epoch_loss)
    val_running_corrects_history.append(val_epoch_acc)

    print(f'training loss: {epoch_loss:.4f}, acc: {epoch_acc.item():.4f} ')
    print(f'validation loss: {val_epoch_loss:.4f}, acc: {val_epoch_acc.item():.4f} ')

In [None]:
# Per-epoch loss curves, with epochs numbered from 1 to match the training log.
fig, ax = plt.subplots()
ax.plot(range(1, len(running_loss_history) + 1), running_loss_history, label='training loss')
ax.plot(range(1, len(val_running_loss_history) + 1), val_running_loss_history, label='validation loss')
ax.set(title='Loss per epoch', xlabel='epoch', ylabel='loss')
ax.legend();

In [None]:
# float() accepts 0-dim tensors on any device as well as plain numbers, and
# building new lists leaves the history globals untouched, so this cell can be
# re-run safely.
train_acc = [float(x) for x in running_corrects_history]
val_acc = [float(x) for x in val_running_corrects_history]

fig, ax = plt.subplots()
ax.plot(range(1, len(train_acc) + 1), train_acc, label='training accuracy')
ax.plot(range(1, len(val_acc) + 1), val_acc, label='validation accuracy')
ax.set(title='Accuracy per epoch', xlabel='epoch', ylabel='accuracy')
ax.legend();

In [None]:
# Predict the first test batch and show up to 20 images titled
# "predicted (true)": green when correct, red otherwise.
images, labels = next(iter(test_loader))
images = images.to(device)
labels = labels.to(device)

# Inference only: skip building an autograd graph.
with torch.no_grad():
    _, _, _, div32 = extractor(images)
    output = head(div32 / 15.)
preds = output.argmax(dim=1)

fig = plt.figure(figsize=(25, 4))

for idx in range(min(20, len(images))):
    ax = fig.add_subplot(2, 10, idx + 1, xticks=[], yticks=[])
    ax.imshow(im_convert(images[idx]))
    is_correct = bool(preds[idx] == labels[idx])
    ax.set_title(f"{classes[preds[idx].item()]} ({classes[labels[idx].item()]})",
                 color=("green" if is_correct else "red"))