In [2]:
import torch
import torchvision
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torchvision import models
import pandas as pd
import numpy as np
import os
try:
  import pytorch_lightning as pl
except:
  !pip install pytorch_lightning
import torchvision.transforms as transforms

In [3]:
import sys
import google.colab.drive as drive

# Mount Google Drive so the dataset stored under "My Drive" is readable from this runtime.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# Quick look at the training label table (image file name -> class id).
TRAIN_CSV = '/content/drive/My Drive/AML/cassava-leaf-disease-classification/train.csv'
data = pd.read_csv(TRAIN_CSV)
data.head()

Unnamed: 0,image_id,label
0,1000015157.jpg,0
1,1000201771.jpg,3
2,100042118.jpg,1
3,1000723321.jpg,1
4,1000812911.jpg,3


In [5]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class CassavaDataset(Dataset):
    """Cassava leaf images paired with their integer labels.

    The image list and labels are read from ``train.csv`` (images in
    ``train_images/``) or, for the test split, from ``sample_submission.csv``
    (images in ``test_images/``); all paths are relative to ``root_dir``.

    Args:
        root_dir: Directory holding the CSV files and the image folders.
        transform: Optional callable applied to each PIL image.
        stage: Any truthy value selects the test split; a falsy value
            (the default) selects the training split.
    """

    def __init__(self, root_dir, transform=None, stage=None):
        if stage:
            csv_name, images_subdir = "sample_submission.csv", "test_images"
        else:
            csv_name, images_subdir = "train.csv", "train_images"
        csv_output = pd.read_csv(os.path.join(root_dir, csv_name))
        self.images_dir = os.path.join(root_dir, images_subdir)
        # Image file names from the "image_id" column (attribute name kept for
        # compatibility; these are not URLs).
        self.image_urls = np.asarray(csv_output["image_id"])
        self.labels = np.asarray(csv_output["label"])
        self.transform = transform

    def __len__(self):
        """Number of samples (rows of the selected CSV)."""
        return len(self.image_urls)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to 3-channel RGB so a 3-channel pipeline (e.g. a
        3-value ``Normalize``) never receives grayscale/RGBA/palette input, and
        the file is opened in a ``with`` block so its handle is released right
        away instead of lingering until the lazily-loaded image is collected.
        """
        image_path = os.path.join(self.images_dir, self.image_urls[idx])
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

In [6]:
from torch.utils.data import random_split
import math
class CassavaDataModule(pl.LightningDataModule):
    """Lightning DataModule wrapping ``CassavaDataset``.

    ``setup`` splits the labelled training data into train/validation subsets
    with a seeded ``random_split`` and builds the test split separately.

    Args:
        root_dir: Dataset root passed to ``CassavaDataset``.
        transform: Transform applied to every split.
        batch_size: Batch size used by all three loaders.
        train_fraction: Fraction of the labelled data used for training; the
            remainder becomes the validation split (default 0.7).
        seed: Seed for the train/validation split so it is reproducible
            (default 42).
    """

    def __init__(self, root_dir, transform=None, batch_size=32, train_fraction=0.7, seed=42):
        super().__init__()
        self.batch_size = batch_size
        self.root_dir = root_dir
        self.transform = transform
        self.train_fraction = train_fraction
        self.seed = seed

    def setup(self, stage=None):
        """Create ``cassava_train``/``cassava_val`` (seeded split) and ``cassava_test``."""
        cassava_full = CassavaDataset(self.root_dir, self.transform)
        train_data_len = math.floor(len(cassava_full) * self.train_fraction)
        val_data_len = len(cassava_full) - train_data_len
        self.cassava_train, self.cassava_val = random_split(
            cassava_full,
            [train_data_len, val_data_len],
            generator=torch.Generator().manual_seed(self.seed),
        )
        self.cassava_test = CassavaDataset(self.root_dir, self.transform, stage="test")

    def train_dataloader(self):
        # Reshuffle every epoch; without shuffle=True the model would see the
        # exact same batch order in every epoch.
        return DataLoader(self.cassava_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cassava_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cassava_test, batch_size=self.batch_size)

In [7]:
# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
# ImageNet statistics ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) are the usual
# alternative when starting from ImageNet-pretrained weights; the ResNet below is
# trained from scratch.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

root_dir = r'/content/drive/My Drive/AML/cassava-leaf-disease-classification/'
cassava_data = CassavaDataModule(root_dir, transform, batch_size=4)
cassava_data.setup()

# Use each split's own loader so validation/test metrics are not computed on
# the training data.
train_loader = cassava_data.train_dataloader()
val_loader = cassava_data.val_dataloader()
test_loader = cassava_data.test_dataloader()

In [9]:
import torch
from torch import nn
from torch import optim

try:
  import torchbearer
except:
    !pip install torchbearer
from torchbearer import Trial
import torchvision


def Conv1(in_planes, places, stride=2):
    """ResNet stem: 7x7 conv (stride ``stride``) -> BN -> ReLU -> 3x3 stride-2 max-pool.

    With the default stride the spatial size is reduced 4x (224x224 -> 56x56)
    and the output has ``places`` channels.
    """
    return nn.Sequential(
        nn.Conv2d(in_planes, places, kernel_size=7, stride=stride, padding=3, bias=False),
        nn.BatchNorm2d(places),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    )


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand, plus shortcut.

    Args:
        in_places: Number of input channels.
        places: Bottleneck width; the block outputs ``places * expansion`` channels.
        stride: Stride of the 3x3 conv (spatial downsampling factor).
        downsampling: If True, the shortcut is a strided 1x1 conv + BN projection
            so it matches the main path's channels and spatial size.
        expansion: Output-channel multiplier.
    """

    def __init__(self, in_places, places, stride=1, downsampling=False, expansion=4):
        super().__init__()
        self.expansion = expansion
        self.downsampling = downsampling
        out_places = places * expansion
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_places, places, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            nn.Conv2d(places, places, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            nn.Conv2d(places, out_places, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_places),
        )
        if downsampling:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_places, out_places, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_places),
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.downsample(x) if self.downsampling else x
        out = self.bottleneck(x)
        out = out + residual
        return self.relu(out)


class ResNet(nn.Module):
    """Bottleneck ResNet classifier (ResNet-50/101/152 family).

    Args:
        blocks: Number of bottleneck blocks in each of the four stages,
            e.g. ``[3, 8, 36, 3]`` for ResNet-152.
        num_classes: Size of the output logit vector.
        expansion: Bottleneck output-channel multiplier (4 in the standard design).
    """

    def __init__(self, blocks, num_classes=5, expansion=4):
        super(ResNet, self).__init__()
        self.expansion = expansion
        self.conv1 = Conv1(in_planes=3, places=64)

        # Each stage consumes the previous stage's output width (places * expansion);
        # deriving it from `expansion` keeps the widths consistent for any expansion.
        self.layer1 = self.make_layer(in_places=64, places=64, block=blocks[0], stride=1)
        self.layer2 = self.make_layer(in_places=64 * expansion, places=128, block=blocks[1], stride=2)
        self.layer3 = self.make_layer(in_places=128 * expansion, places=256, block=blocks[2], stride=2)
        self.layer4 = self.make_layer(in_places=256 * expansion, places=512, block=blocks[3], stride=2)

        # Global average pool to 1x1: equal to AvgPool2d(7) for 224x224 inputs
        # (the final feature map is 7x7) but also valid for other resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def make_layer(self, in_places, places, block, stride):
        """Build one stage: a projecting block followed by ``block - 1`` identity-shortcut blocks."""
        layers = [Bottleneck(in_places, places, stride, downsampling=True, expansion=self.expansion)]
        for _ in range(1, block):
            layers.append(Bottleneck(places * self.expansion, places, expansion=self.expansion))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
# Classification loss passed to the training Trial.
loss_function = nn.CrossEntropyLoss()


def ResNet152(num_classes=5):
    """Build a ResNet-152: four bottleneck stages of [3, 8, 36, 3] blocks.

    Args:
        num_classes: Number of output logits (default 5, matching ResNet's default).
    """
    return ResNet([3, 8, 36, 3], num_classes=num_classes)

In [10]:
# Build the network, put it in training mode, and create its SGD optimiser
# together with a cyclical learning-rate schedule bound to that optimiser.
model = ResNet152()
model.train()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
# NOTE(review): per PyTorch's CyclicLR implementation, construction presumably
# resets the optimiser's lr to base_lr, and the schedule only advances when
# scheduler.step() is called after each batch — confirm the training cell does so.
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=0.005,
    max_lr=0.01,
)

In [11]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

from torchbearer.callbacks import on_step_training


@on_step_training
def step_cyclic_lr(state):
    """Advance the CyclicLR `scheduler` once per training step."""
    scheduler.step()


# The CyclicLR scheduler is only useful if something steps it; without this
# callback the learning rate would never cycle.
trial = Trial(
    model,
    optimizer,
    loss_function,
    metrics=['loss', 'accuracy'],
    callbacks=[step_cyclic_lr],
).to(device)
trial.with_generators(train_loader, val_loader, test_generator=test_loader)
trial.run(epochs=30)

# NOTE(review): test-split labels are read from sample_submission.csv; if those
# are placeholder labels, the accuracy reported here is not meaningful — confirm.
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
results

RuntimeError: ignored