In [1]:
from torchvision import models
import torch

# Fix the global torch RNG so the randomly initialised heads and the DataLoader
# shuffling below are reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

# Pretrained ResNet-50 from torchvision; its final `fc` layer is replaced below.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=models.ResNet50_Weights.IMAGENET1K_V1`; kept for older environments.
resnet50 = models.resnet50(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/ryanz/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100.0%


In [None]:
"""Replace the last fully connected layer with our own 10-class head"""
resnet50.fc = torch.nn.Linear(in_features=resnet50.fc.in_features, out_features=10)

In [2]:
"""Load the CIFAR-10 training data and normalize it"""
from torchvision.transforms.transforms import Normalize
from torchvision.datasets import CIFAR10
from torchvision import transforms

# Per-channel mean / std given on the 0-255 scale (presumably the CIFAR-10
# training-set statistics), rescaled to [0, 1] to match ToTensor()'s output range.
CIFAR10_MEAN = [channel / 255.0 for channel in (125.3, 123.0, 113.9)]
CIFAR10_STD = [channel / 255.0 for channel in (63.0, 62.1, 66.7)]
normalize = Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD)

cf10_transforms = transforms.Compose([transforms.ToTensor(), normalize])

cifar_10 = CIFAR10('.', train=True, download=True, transform=cf10_transforms)

0.0%

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


1.9%


KeyboardInterrupt: 

In [None]:
"""Visualize one training image using matplotlib"""
from matplotlib import pyplot as plt
image, label = next(iter(cifar_10))

print('LABEL:', label)

# The dataset yields *normalized* tensors, whose values fall outside [0, 1];
# imshow would clip them and distort the colours, so undo `normalize` first.
channel_mean = torch.tensor(normalize.mean).view(-1, 1, 1)
channel_std = torch.tensor(normalize.std).view(-1, 1, 1)
unnormalized = (image * channel_std + channel_mean).clamp(0, 1)

# (C, H, W) tensor -> (H, W, C) array, the layout imshow expects.
plt_img = unnormalized.numpy().transpose(1, 2, 0)
fig, ax = plt.subplots()
ax.imshow(plt_img)
ax.set_title(cifar_10.classes[label])
ax.axis('off');

In [None]:
"""Wrap the dataset in a shuffling DataLoader"""
from torch.utils.data import DataLoader

BATCH_SIZE = 32
train_loader = DataLoader(cifar_10, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
"""Inspect the shapes of a single batch"""

for x, y in train_loader:
  # Only the first batch is needed to see the (images, labels) shapes.
  print(x.shape, y.shape)
  break

In [None]:
"""Run the network (with its new, untrained head) on one batch"""
x, y = next(iter(train_loader))

# eval(): BatchNorm uses its stored running statistics instead of overwriting
# them with this batch's statistics (which train mode does on every call).
# no_grad(): no autograd graph is needed for plain inference.
resnet50.eval()
with torch.no_grad():
  preds = resnet50(x)
preds[0:10]

In [None]:
"""Turn the predictions (logits) into probabilities"""
from torch.nn.functional import softmax

# Store the probabilities under a new name so `preds` keeps the logits;
# overwriting `preds` would apply softmax twice if this cell were re-run.
probs = softmax(preds, dim=-1)
probs[0:10]

In [None]:
"""Convert the first 10 predictions into predicted class labels"""
# softmax is monotonic, so argmax gives the same class on logits or probabilities.
pred_labels = preds[:10].argmax(dim=-1)
pred_labels

In [None]:
"""Compare the predictions with the true labels. They are essentially random,
because the last fully connected layer was replaced with a new, untrained one."""
y[:10]

In [None]:
"""Create a frozen, pretrained backbone used purely as a feature extractor"""
# eval(): BatchNorm keeps its pretrained running statistics instead of updating
# them on CIFAR-10 batches. requires_grad_(False): the weights are never trained.
# Both calls return the module itself, so they can be chained.
backbone = models.resnet50(pretrained=True).eval().requires_grad_(False)

In [None]:
# Trainable head: maps the backbone's output vector (`fc.out_features` values,
# i.e. its 1000 logits) to the 10 CIFAR-10 classes.
finetune_layer = torch.nn.Linear(in_features=backbone.fc.out_features, out_features=10)

In [None]:
from pl_bolts.datamodules import CIFAR10DataModule

# Lightning-Bolts CIFAR-10 datamodule, storing the data in the current directory.
dm = CIFAR10DataModule(data_dir='.')

In [None]:
from torch.nn.functional import cross_entropy
from torch.optim import Adam

dm.prepare_data()
dm.setup()

# The backbone is a frozen feature extractor. In train mode its BatchNorm layers
# would overwrite their pretrained running statistics on every forward pass,
# even under torch.no_grad(), so force eval mode here.
backbone.eval()

# Only the new head is optimized.
optimizer = Adam(finetune_layer.parameters(), lr=1e-4)
for epoch in range(10):
  epoch_loss, num_batches = 0.0, 0
  for x, y in dm.train_dataloader():
    with torch.no_grad():
      # (b, 3, 32, 32) -> (b, 1000)
      features = backbone(x)

    # (b, 1000) -> (b, 10)
    preds = finetune_layer(features)
    loss = cross_entropy(preds, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    epoch_loss += loss.item()
    num_batches += 1

  # One summary line per epoch instead of one line per batch keeps the output readable.
  print(f'epoch {epoch}: mean train loss {epoch_loss / max(num_batches, 1):.4f}')

In [None]:
"""Turn to pytorch Lightning"""
import pytorch_lightning as pl

class ImageClassifier(pl.LightningModule):
  """Frozen pretrained ResNet-50 backbone with a trainable linear classification head.

  Args:
    num_classes: number of outputs of the linear head (10 for CIFAR-10).
    lr: Adam learning rate used for the head.
  """

  def __init__(self, num_classes=10, lr=1e-3):
    # Previously misspelled `__intit__`, so it never ran: no backbone, no head
    # and no saved hyperparameters.
    super().__init__()
    self.save_hyperparameters()

    self.backbone = models.resnet50(pretrained=True)
    # Used purely as a feature extractor, so its weights are frozen.
    self.backbone.requires_grad_(False)
    # Use this instance's backbone (not the notebook-global `backbone`) and the
    # saved hyperparameter, so the head size actually follows `num_classes`.
    self.finetune_layer = torch.nn.Linear(self.backbone.fc.out_features,
                                          self.hparams.num_classes)

  def train(self, mode=True):
    """Switch train/eval mode, but always keep the frozen backbone in eval mode.

    Lightning puts the whole module into train mode during fit; without this the
    backbone's BatchNorm layers would update their running statistics.
    """
    super().train(mode)
    self.backbone.eval()
    return self

  def forward(self, x):
    """Return class logits for a batch of images."""
    with torch.no_grad():
      # (b, 3, 32, 32) -> (b, 1000)
      features = self.backbone(x)
    # (b, 1000) -> (b, num_classes)
    return self.finetune_layer(features)

  def training_step(self, batch, batch_idx):
    x, y = batch
    preds = self(x)
    loss = cross_entropy(preds, y)
    self.log('train_loss', loss)
    return loss

  def configure_optimizers(self):
    # Only the head is trainable; the backbone's parameters never get gradients.
    optimizer = Adam(self.finetune_layer.parameters(), lr=self.hparams.lr)
    return optimizer


In [None]:
# Seed Python, NumPy and torch RNGs so the head init and shuffling are reproducible.
pl.seed_everything(42)

classifier = ImageClassifier()

# Without max_epochs, Lightning 1.x defaults to 1000 epochs; 10 matches the manual loop.
# NOTE(review): `progress_bar_refresh_rate` is deprecated in Lightning 1.5 and removed
# in 1.7; newer versions use `callbacks=[TQDMProgressBar(refresh_rate=20)]` instead.
trainer = pl.Trainer(max_epochs=10, progress_bar_refresh_rate=20)
trainer.fit(classifier, dm)
