In [49]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

In [50]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)

cuda:0


In [51]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

batch_size = 512
trainset, validset = torch.utils.data.random_split(
    torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform = transform),
    lengths=[40000, 10000])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

classes = ('skunk', 'caterpillar', 'aquarium_fish', 'house', 'porcupine', 'castle', 'bear', 
           'cloud', 'snake', 'rabbit', 'fox', 'maple_tree', 'hamster', 'bed', 'cattle', 'chair', 
           'seal', 'squirrel', 'couch', 'bottle', 'pine_tree', 'tank', 'keyboard', 'raccoon', 
           'cockroach', 'pickup_truck', 'orange', 'spider', 'beetle', 'streetcar', 'dinosaur', 
           'tulip', 'wardrobe', 'lion', 'train', 'clock', 'crab', 'palm_tree', 'rocket', 'forest', 
           'mouse', 'shrew', 'chimpanzee', 'sweet_pepper', 'tiger', 'telephone', 'road', 'bicycle', 
           'woman', 'butterfly', 'wolf', 'lizard', 'bus', 'mountain', 'poppy', 'beaver', 'trout', 
           'man', 'bridge', 'pear', 'plain', 'table', 'ray', 'whale', 'motorcycle', 'crocodile', 
           'dolphin', 'snail', 'plate', 'cup', 'oak_tree', 'girl', 'apple', 'baby', 'television', 
           'skyscraper', 'flatfish', 'kangaroo', 'turtle', 'worm', 'elephant', 'lobster', 'tractor', 
           'bee', 'willow_tree', 'rose', 'can', 'sunflower', 'possum', 'sea', 'bowl', 'otter', 'shark', 
           'camel', 'lamp', 'lawn_mower', 'mushroom', 'orchid', 'boy', 'leopard')

Files already downloaded and verified
Files already downloaded and verified


 #### Define a Convolutional Neural Network

In [52]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        
        return out

class ResNet18(nn.Module):
    def __init__(self, num_classes=100):
        super(ResNet18, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(BasicBlock, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, num_blocks=2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels

        for _ in range(1, num_blocks):
            layers.append(block(out_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

#### Define a Loss function and optimizer

In [53]:
import torch.optim as optim
net = ResNet18().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

#### Train the network

In [54]:
from tqdm.notebook import tqdm
# progress_bar = tqdm(enumerate(trainloader, 0), total=len(trainloader), desc='Training')

for epoch in range(50):  # epoch
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        net.train()
        # zero gradient
        optimizer.zero_grad()
        
        # forward, backward, optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        #wandb.log({"loss": loss})
        #progress_bar.update(1)
        running_loss += loss.item()
        
        if i % 100 == 0:
            # validation
            net.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for data in validloader: 
                    images, labels = data[0].to(device), data[1].to(device)
                    outputs = net(images)
                    # the class with the highest energy is what we choose as prediction
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            
            print("Epoch : " + "{:<3}".format(epoch) + " |  Accuracy : " + "{:<4}".format(100. * correct / total) + "%")
        running_loss = 0.0


print('Finished Training')


Epoch : 0   |  Accuracy :  1.07 %
Epoch : 0   |  Accuracy :  10.16 %
Epoch : 1   |  Accuracy :  13.85 %
Epoch : 1   |  Accuracy :  17.27 %
Epoch : 2   |  Accuracy :  18.89 %
Epoch : 2   |  Accuracy :  21.29 %
Epoch : 3   |  Accuracy :  21.96 %
Epoch : 3   |  Accuracy :  22.59 %
Epoch : 4   |  Accuracy :  23.48 %
Epoch : 4   |  Accuracy :  24.64 %
Epoch : 5   |  Accuracy :  25.1 %
Epoch : 5   |  Accuracy :  25.17 %
Epoch : 6   |  Accuracy :  25.33 %
Epoch : 6   |  Accuracy :  26.16 %
Epoch : 7   |  Accuracy :  26.26 %
Epoch : 7   |  Accuracy :  25.82 %
Epoch : 8   |  Accuracy :  25.47 %
Epoch : 8   |  Accuracy :  26.2 %
Epoch : 9   |  Accuracy :  26.07 %
Epoch : 9   |  Accuracy :  26.01 %
Epoch : 10  |  Accuracy :  25.87 %
Epoch : 10  |  Accuracy :  26.19 %
Epoch : 11  |  Accuracy :  25.81 %
Epoch : 11  |  Accuracy :  26.21 %
Epoch : 12  |  Accuracy :  25.27 %
Epoch : 12  |  Accuracy :  26.25 %
Epoch : 13  |  Accuracy :  26.32 %
Epoch : 13  |  Accuracy :  26.62 %
Epoch : 14  |  Accuracy

#### Testing

In [55]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader: 
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * correct / total} %')

Test Accuracy: 26.19 %


#### Save Model

In [57]:
torch.save(net, './data/model_train.pt')
print('Model Saved')

Model Saved
