In [7]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import CIFAR10

from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from tqdm import tqdm
import torchvision.models as models

In [8]:
class Model(pl.LightningModule):
    def __init__(self):
        super(Model, self).__init__()
        self.resnet18 = models.resnet18(pretrained=True)
        self.resnet18.fc = nn.Linear(512, 10)
        #self.output = torch.nn.Linear(1000, 10)
        
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

    def forward(self, x):
        y = self.resnet18(x)
        #y = torch.relu(y)
        #y = self.output(y)
        y = F.softmax(y,dim=1)
        return y

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_nb): 
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        preds = torch.argmax(y, dim=1)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_acc(y,t), prog_bar=True)
        return loss

    def test_step(self, batch, batch_nb):
        return self.validation_step(batch, batch_nb)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

In [9]:
transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
train_dataset = CIFAR10(".", train=True, download=True,transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to .


In [10]:
train_loader = DataLoader(train_dataset, batch_size=512)

In [19]:
model = Model()
#for param in model.resnet18.parameters():
#    param.requires_grad = False
trainer = pl.Trainer(gpus=1,max_epochs=1,)
trainer

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


AttributeError: 'Trainer' object has no attribute 'validation'

In [12]:
trainer.fit(model, train_loader) 


  | Name      | Type     | Params
---------------------------------------
0 | resnet18  | ResNet   | 11.2 M
1 | train_acc | Accuracy | 0     
2 | val_acc   | Accuracy | 0     
3 | test_acc  | Accuracy | 0     
---------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

1

In [15]:
test_dataset = CIFAR10(".", train=False, download=True, transform=transform)

Files already downloaded and verified


In [16]:
test_loader = DataLoader(test_dataset, batch_size=32)

In [17]:
trainer.test(model,test_dataloaders=test_loader)



Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_acc': 0.16859999299049377, 'val_loss': 2.289808750152588}
--------------------------------------------------------------------------------


[{'val_loss': 2.289808750152588, 'val_acc': 0.16859999299049377}]

In [None]:
dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True, input_names=["input"], output_names=["output"])

In [14]:
model.summarize

<bound method LightningModule.summarize of Model(
  (resnet18): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1