In [23]:
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
import torchvision
import tqdm.notebook as tq
import matplotlib.pyplot as plt
import numpy as np
import pytorchtools as pt
%matplotlib notebook

In [32]:
# Hyper-parameters and runtime settings used by the cells below.
NUM_WORKERS = 8     # worker processes per DataLoader
BATCH_SIZE = 128    # samples per mini-batch (train and test)
INPUT_SIZE = 128    # images are resized to INPUT_SIZE x INPUT_SIZE
LR = 1e-4           # learning rate handed to the optimizer
NUM_EPOCHS = 10     # number of passes over the training set
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) on a machine without CUDA.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Model

In [33]:
# ResNet-18 backbone loaded with pretrained weights.
model = torchvision.models.resnet18(pretrained=True)

# Replace the classifier head with a 2-unit linear layer followed by a
# 10-way linear output layer, then re-initialise both weight matrices with
# Xavier-normal values (biases keep their default initialisation).
model.fc = nn.Sequential(
    nn.Linear(512, 2),
    nn.Linear(2, 10),
)
for _fc_layer in model.fc:
    nn.init.xavier_normal_(_fc_layer.weight)

# Register the 2-unit layer ("fc.0") with the monitor.
# NOTE(review): presumably ForwardMonitor records that layer's forward
# output so it can be read back via get_layer() - confirm in pytorchtools.
monitor = pt.ForwardMonitor(model, verbose=False)
monitor.add_layer("fc.0")


# Datasets

In [34]:
# Training data: resize to the network input size, apply a random horizontal
# flip and a small colour jitter, then convert to a tensor.
tr_train = T.Compose(
    [
        T.Resize((INPUT_SIZE, INPUT_SIZE)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=.1, contrast=.1, saturation=.1),
        T.ToTensor(),
    ]
)

# CIFAR-10 training split, downloaded into the working directory if missing.
data_train = CIFAR10(root=".", train=True, transform=tr_train, download=True)

# Shuffled mini-batches; the final incomplete batch is dropped.
load_train = torch.utils.data.DataLoader(
    dataset=data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
)

Files already downloaded and verified


In [35]:
# Test data: only resize and convert to a tensor (no augmentation).
tr_test = T.Compose(
    [
        T.Resize((INPUT_SIZE, INPUT_SIZE)),
        T.ToTensor(),
    ]
)

# CIFAR-10 test split, read from the working directory (no download here).
data_test = CIFAR10(root=".", train=False, transform=tr_test)

# Fixed-order mini-batches; the last, smaller batch is kept.
load_test = torch.utils.data.DataLoader(
    dataset=data_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# Optimization

In [36]:
# Adam over every model parameter (backbone and new head) at learning rate LR.
optim = torch.optim.Adam(params=model.parameters(), lr=LR)

In [37]:
# Cross-entropy between the model's 10-way output and the integer class labels.
criterion = torch.nn.CrossEntropyLoss()

# Training

In [38]:
from sklearn.metrics import accuracy_score

In [39]:
def visualize_features(fig, ax, features, labels, title="", class_names=None):
    """Redraw a 2-D scatter plot of ``features`` on ``ax``, coloured by class id.

    Args:
        fig: Figure that owns ``ax``; its canvas is redrawn and flushed so the
            plot refreshes while a loop is still running.
        ax: Axes to clear and draw on.
        features: 2-D array; column 0 is plotted on x and column 1 on y.
        labels: Integer class id for each row of ``features``.
        title: Text to set as the axes title.
        class_names: Sequence mapping class id -> display name. Defaults to
            ``data_train.classes`` when omitted.
    """
    if class_names is None:
        class_names = data_train.classes

    ax.clear()
    # Pin the colour scale to the full class-id range so a class keeps the
    # same colour on every redraw, whichever classes happen to be present.
    scatter = ax.scatter(features[:, 0],
                         features[:, 1],
                         s=10,
                         c=labels,
                         cmap="Paired",
                         vmin=0,
                         vmax=len(class_names) - 1)

    # num=None yields one handle per unique class id actually plotted, in
    # ascending order; pair each handle with the name of *that* id rather
    # than with the first len(handles) names.
    handles = scatter.legend_elements(num=None)[0]
    present_ids = np.unique(labels)
    ax.legend(handles=handles,
              labels=[class_names[int(class_id)] for class_id in present_ids],
              loc="upper right",
              fontsize=8)

    ax.title.set_text(title)

    fig.canvas.draw()
    fig.canvas.flush_events()

In [41]:
# Two live panels: ax[0] shows training features, ax[1] shows test features.
fig, ax = plt.subplots(1, 2, figsize=(10,5))
plt.ion()

n_iter_train = len(load_train)
n_iter_test = len(load_test)
# Refresh each panel about ten times per pass. Clamp to at least 1 so a
# loader with fewer than ten batches cannot cause a modulo-by-zero.
viz_every_train = max(1, int(.1*n_iter_train))
viz_every_test = max(1, int(0.1*n_iter_test))

model.to(DEVICE)
for epoch in range(NUM_EPOCHS):
    # ---- train ----
    model.train()
    feature_points = []
    label_values = []
    # `step` rather than `iter`, so the builtin iter() is not shadowed in
    # the notebook's global namespace.
    for step, (images, labels) in enumerate(tq.tqdm(load_train)):
        optim.zero_grad()
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        out = model(images)

        loss = criterion(out, labels)
        loss.backward()

        optim.step()

        if step % viz_every_train == 0:
            # Value the monitor holds for layer "fc.0" after this forward pass.
            # NOTE(review): assumed to be that layer's output for the current
            # batch - confirm against pytorchtools.
            features_val = monitor.get_layer("fc.0").detach().squeeze(-1).squeeze(-1).cpu().numpy()
            label_values.append(labels.cpu().numpy())
            feature_points.append(features_val)

            # Plot every batch sampled so far in this epoch, coloured by true label.
            visualize_features(fig,
                               ax[0],
                               np.vstack(feature_points),
                               np.concatenate(label_values),
                               "train")

    # ---- test ----
    features_test = []
    y_pred = []
    y_test = []
    model.eval()

    for step, (images, labels) in enumerate(tq.tqdm(load_test)):
        images = images.to(DEVICE)
        y_test.append(labels.cpu().numpy())
        with torch.no_grad():
            out = model(images)

        y_pred.append(torch.argmax(out, 1).cpu().numpy())
        features = monitor.get_layer("fc.0").squeeze(-1).squeeze(-1).cpu().numpy()
        features_test.append(features)

        if step % viz_every_test == 0:
            # NOTE(review): the test panel is coloured by *predicted* class,
            # while the train panel uses true labels - confirm this is intended.
            visualize_features(fig,
                               ax[1],
                               np.vstack(features_test),
                               np.concatenate(y_pred),
                               "test")

    test_acc = accuracy_score(np.concatenate(y_test), np.concatenate(y_pred))
    print(f"Test Accuracy {test_acc*100:.4f}")

<IPython.core.display.Javascript object>

  0%|          | 0/390 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy 67.7100


  0%|          | 0/390 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy 80.3600


  0%|          | 0/390 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy 80.8500


  0%|          | 0/390 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy 81.1400


  0%|          | 0/390 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy 81.4000


  0%|          | 0/390 [00:00<?, ?it/s]

KeyboardInterrupt: 