In [6]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


In [7]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:
import sys
sys.path.append('/content/drive/MyDrive/Colab Notebooks')

from ResNet import ResNet, ResNet101

In [9]:
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [10]:
train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

trainloader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True, num_workers=2)

test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

testloader = torch.utils.data.DataLoader(test, batch_size=128,shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 101430070.20it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [11]:
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = ResNet101(3).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.1, patience=5)

In [13]:
EPOCHS = 200
for epoch in range(EPOCHS):
    losses = []
    running_loss = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for i, inp in enumerate(trainloader):
        inputs, labels = inp
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        losses.append(loss.item())

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if i%100 == 0 and i > 0:
            print(f'Loss [{epoch+1}, {i}](epoch, minibatch): ', running_loss / 100)
            running_loss = 0.0

    avg_loss = sum(losses)/len(losses)
    scheduler.step(avg_loss)

print('Training Done')

Loss [1, 100](epoch, minibatch):  6.9983795499801635
Loss [1, 200](epoch, minibatch):  3.43063227891922
Loss [1, 300](epoch, minibatch):  2.6868314385414123
Loss [2, 100](epoch, minibatch):  2.34351522564888
Loss [2, 200](epoch, minibatch):  2.189417153596878
Loss [2, 300](epoch, minibatch):  2.028717914819717
Loss [3, 100](epoch, minibatch):  1.9946791803836823
Loss [3, 200](epoch, minibatch):  1.943064739704132
Loss [3, 300](epoch, minibatch):  1.9243985247611999
Loss [4, 100](epoch, minibatch):  1.9229141187667846
Loss [4, 200](epoch, minibatch):  1.8872440791130065
Loss [4, 300](epoch, minibatch):  1.8746490454673768
Loss [5, 100](epoch, minibatch):  1.8788133943080902
Loss [5, 200](epoch, minibatch):  1.8546157193183899
Loss [5, 300](epoch, minibatch):  1.8262171459197998
Loss [6, 100](epoch, minibatch):  1.8502129077911378
Loss [6, 200](epoch, minibatch):  1.8054807829856871
Loss [6, 300](epoch, minibatch):  1.7503284955024718
Loss [7, 100](epoch, minibatch):  1.7242645370960235


In [14]:
correct = 0
total = 0

with torch.no_grad():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy on 10,000 test images: ', 100*(correct/total), '%')

Accuracy on 10,000 test images:  85.9 %
