In [3]:
from google.colab import drive
drive.mount('/content/gdrive')
%cd gdrive/My Drive/Colab Notebooks/mixup

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive
/content/gdrive/My Drive/Colab Notebooks/mixup


In [4]:
import torch
import torch.nn as nn
from torch.nn.functional import one_hot
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import os

from tqdm import tqdm

from utils import progress_bar
from resnet import ResNet18

In [5]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on a runtime that has no accelerator attached.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Last expression of the cell: report which device was picked up.
# get_device_name() is only queried when CUDA is actually available.
torch.cuda.get_device_name() if torch.cuda.is_available() else 'cpu'

'Tesla T4'

In [6]:
# The same per-channel mean / std standardisation is applied to both splits.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# CIFAR-10 training split, reshuffled every epoch, in batches of 128.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

# CIFAR-10 test split, iterated in a fixed order, in batches of 128.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)

# Human-readable class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [7]:
def imshow(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Undo the per-channel normalisation of an image and display it.

    Args:
        img: channel-first image tensor (indexed as ``img[channel]`` and
            transposed with axes ``(1, 2, 0)`` before plotting).
        mean: per-channel means that were subtracted during normalisation.
        std: per-channel standard deviations that were divided out.

    The input tensor is left untouched, so the same image can be shown
    repeatedly without being de-normalised more than once.
    """
    # Work on a detached CPU copy: writing into `img` itself would alter the
    # caller's data, and `.numpy()` needs a CPU tensor without autograd history.
    restored = img.detach().cpu().clone()
    for channel, (channel_std, channel_mean) in enumerate(zip(std, mean)):
        restored[channel] = restored[channel] * channel_std + channel_mean
    # matplotlib expects channel-last layout.
    plt.imshow(np.transpose(restored.numpy(), (1, 2, 0)))
    plt.show()

In [8]:
def mixup(batch, label, alpha=1.0):
    """Blend every sample of a batch with a randomly chosen partner sample.

    Args:
        batch: input tensor whose first dimension indexes the samples.
        label: labels aligned with ``batch`` along the first dimension.
        alpha: concentration of the Beta(alpha, alpha) distribution the mixing
            coefficient is drawn from. ``alpha <= 0`` disables mixing
            (the coefficient is fixed to 1).

    Returns:
        ``(mixup_batch, (label, new_label), lam)`` where
        ``mixup_batch = lam * batch + (1 - lam) * batch[index]``, ``new_label``
        is ``label[index]`` for the same random permutation ``index``, and
        ``lam`` is the mixing coefficient as a Python float.
    """
    # np.random.beta rejects non-positive parameters, so treat alpha <= 0 as
    # "mixup disabled" instead of raising.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    # Create the permutation on the same device as the batch so that indexing
    # does not need a CPU index tensor for a GPU batch.
    index = torch.randperm(batch.size(0), device=batch.device)
    new_batch = batch[index]
    new_label = label[index]
    mixup_batch = lam * batch + (1 - lam) * new_batch
    return mixup_batch, (label, new_label), lam

def mix_label(label, new_label, lam, output, loss_function):
    """Return the mixup loss: a ``lam``-weighted blend of two losses.

    ``loss_function`` is evaluated on ``output`` once against ``label`` and
    once against ``new_label``; the results are combined as
    ``lam * first + (1 - lam) * second``.
    """
    loss_original = loss_function(output, label)
    loss_partner = loss_function(output, new_label)
    return lam * loss_original + (1 - lam) * loss_partner

In [9]:
def train(model, mix=False):
    """Run one training epoch over the global ``train_loader``.

    Relies on the notebook-level ``train_loader``, ``device``, ``optimizer``,
    ``loss_function`` and ``progress_bar``, plus ``mixup`` / ``mix_label``
    when ``mix`` is truthy.

    Args:
        model: network to optimise; put into training mode here.
        mix: when truthy, every batch is blended with ``mixup`` and the loss
            is the matching ``lam``-weighted combination from ``mix_label``.
    """
    model.train()
    train_loss = 0
    total = 0
    correct = 0
    for i, (batch, label) in enumerate(train_loader):
        batch, label = batch.to(device), label.to(device)

        if mix:
            batch, (label, new_label), lam = mixup(batch, label)

        optimizer.zero_grad()
        output = model(batch)
        if mix:
            loss = mix_label(label, new_label, lam, output, loss_function)
        else:
            loss = loss_function(output, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = output.max(1)
        total += label.size(0)
        if mix:
            # A blended input carries both label sets, so credit each with the
            # same weight the loss uses; scoring against `label` alone would
            # ignore the partner labels entirely.
            correct += (lam * predicted.eq(label).sum().item()
                        + (1 - lam) * predicted.eq(new_label).sum().item())
        else:
            correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        # With mixup `correct` is fractional; %d truncates it for display only.
        progress_bar(i, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'%(train_loss/(i + 1), acc, correct, total))

In [10]:
def test(model):
    """Evaluate ``model`` on the global ``test_loader`` and return its accuracy.

    Uses the notebook-level ``test_loader``, ``device``, ``loss_function`` and
    ``progress_bar``. The model is switched to eval mode and gradients are
    disabled for the whole pass.

    Returns:
        Accuracy in percent, as computed after the last batch.
    """
    model.eval()
    running_loss, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(test_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            running_loss += loss_function(logits, targets).item()

            # Index of the largest logit per sample is the predicted class.
            predicted = logits.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

            acc = 100. * n_correct / n_seen
            message = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (running_loss / (step + 1), acc, n_correct, n_seen)
            progress_bar(step, len(test_loader), message)
    return acc

In [11]:
def save_model(epoch, acc):
    """Checkpoint the global ``model`` when ``acc`` beats the best seen so far.

    Reads and updates the notebook-level ``best_acc``. On improvement the
    model's state dict, the accuracy and the epoch are written to
    ``./checkpoint/model_{epoch}.pth`` (the directory is created if missing).
    """
    global best_acc
    # Nothing to do unless this is a new best.
    if not (acc > best_acc):
        return
    print('Saving..')
    state = dict(model=model.state_dict(), acc=acc, epoch=epoch)
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, f'./checkpoint/model_{epoch}.pth')
    best_acc = acc

def load_model(name, map_location='cpu'):
    """Read ``./checkpoint/{name}.pth`` and return the stored model state dict.

    Args:
        name: checkpoint file name without the ``.pth`` extension.
        map_location: forwarded to ``torch.load``. The default ``'cpu'`` lets
            a checkpoint saved from a GPU run be read on a machine without CUDA.

    Returns:
        The value stored under the ``'model'`` key of the checkpoint.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(f'./checkpoint/{name}.pth', map_location=map_location)
    return checkpoint['model']

In [8]:
def decaying_learning_rate(optimizer, epoch, base_lr=1e-1, milestones=(100, 150), factor=10):
    """Set a step-decayed learning rate on every parameter group.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts whose
            ``'lr'`` entry is overwritten.
        epoch: epoch index the learning rate is computed for.
        base_lr: learning rate before any milestone is reached.
        milestones: epochs at (and after) which the rate is divided once more.
        factor: divisor applied at each reached milestone.

    With the defaults the rate is 0.1 below epoch 100, divided by 10 from
    epoch 100, and divided by 10 again from epoch 150.
    """
    lr = base_lr
    for milestone in milestones:
        if epoch >= milestone:
            lr /= factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [14]:
# Fresh ResNet18 trained with mixup: SGD with momentum, weight decay and a
# step-decayed learning rate.
model = ResNet18()
model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
best_acc = 0  # best test accuracy so far; updated by save_model
for epoch in range(200):
    print(epoch)
    # Set the learning rate for *this* epoch before training on it. Calling
    # the schedule after the epoch would apply each `epoch >= milestone` drop
    # one epoch late.
    decaying_learning_rate(optimizer, epoch)
    train(model, True)
    test_acc = test(model)
    save_model(epoch, test_acc)

0
Saving..
1
Saving..
2
Saving..
3
Saving..
4
Saving..
5
6
Saving..
7
Saving..
8
9
Saving..
10
11
12
Saving..
13
Saving..
14
Saving..
15
16
Saving..
17
18
Saving..
19
Saving..
20
21
Saving..
22
23
24
Saving..
25
26
27
28
29
30
Saving..
31
32
33
34
35
36
Saving..
37
38
39
40
Saving..
41
42
43
44
Saving..
45
46
47
Saving..
48
49
50
51
52
53
54
Saving..
55
56
57
58
59
60
61
62
Saving..
63
64
65
66
67
68
69
70
71
72
73
74
75
Saving..
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
Saving..
102
Saving..
103
104
105
Saving..
106
107
Saving..
108
109
110
111
112
Saving..
113
114
115
116
117
118
Saving..
119
120
Saving..
121
122
123
124
125
126
127
Saving..
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
Saving..
145
146
147
148
149
150
151
152
153
154
155
156
Saving..
157
158
Saving..
159
160
161
162
Saving..
163
164
165
Saving..
166
167
168
169
170
171
172
173
174
175
176
Saving..
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191


In [14]:
# Evaluate a stored checkpoint on the test split with a freshly built network.
# NOTE(review): save_model writes files named model_{epoch}.pth; 'mixup.pth' is
# presumably a manually renamed copy of one of them -- confirm it exists.
loss_function = nn.CrossEntropyLoss()
model = ResNet18()
model.to(device)
model.load_state_dict(load_model('mixup'))
# Last expression of the cell: the returned test accuracy (in percent) is displayed.
test(model)



96.06