In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

import nets.resnet as resnet
import mi.estim as estim
import mi.critics as critics
import utils.data as data
import models.cpc as cpc

In [2]:
# Index of the CUDA device to use; set to a negative value to force the CPU.
gpu_device = 1

if gpu_device >= 0 and torch.cuda.is_available():
    device = torch.device('cuda:{}'.format(gpu_device))
    # Make `device` the default CUDA device so implicit `.cuda()` allocations
    # elsewhere land on the same GPU.
    torch.cuda.set_device(device)
else:
    # CPU fallback. torch.cuda.set_device only accepts CUDA devices (it raises
    # ValueError for a CPU device), so there is no default device to set here.
    device = torch.device('cpu')

In [3]:
batch_size = 256
patch_size = 12
stride = 4
# One critic per offset p in range(p1, p2), i.e. p = 3, 4, 5.
p1, p2 = 3, 6

# Augmentation handed to the CIFAR-10 loaders. Whether it is applied to the
# test split as well depends on utils.data.cifar10 — TODO confirm.
trans = [
    transforms.RandomGrayscale(.5),
    transforms.RandomHorizontalFlip(.5),
    transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0)),
]
augment = transforms.Compose(trans)

# Number of patch_size-wide windows at the given stride along one 32-pixel side.
# Not used by any later visible cell; name kept unchanged for interactive use.
num_patche = (32 - patch_size) // stride + 1
trainloader, testloader = data.cifar10('../data/', batch_size, patch_size, stride, augment)

encoder = cpc.Encoder(resnet.resnet56(patch_size))
context = cpc.RNNContext(64, 64)
# Keys are the offsets as strings (presumably so cpc.CPC can wrap them in a
# module dict — verify against models.cpc).
critics1 = {str(p): critics.BiLinearCritic(64, 64) for p in range(p1, p2)}
critics2 = {str(p): critics.BiLinearCritic(64, 64) for p in range(p1, p2)}
model = cpc.CPC(encoder, context, critics1, critics2).to(device)

# map_location remaps the saved tensors onto `device`. Without it torch.load
# restores them onto the device they were saved from and raises if that device
# (e.g. a specific GPU) is unavailable on this machine.
model.load_state_dict(torch.load('../saved_models/cpc.chkpt', map_location=device))

Files already downloaded and verified
Files already downloaded and verified


IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [4]:
log_interval = 10   # print the mean training loss every this many iterations
num_epochs = 100
eval_every = 5      # evaluate on the test set every this many epochs

# The pretrained encoder stays frozen: it is put in eval mode once here (only
# the classifier is toggled between train/eval below) and every forward pass
# through it runs under no_grad, so only the linear probe is trained.
encoder.eval()
classifier = nn.Linear(64, 10).to(device)
# NOTE: despite its name this is an Adam optimizer; the name is kept for compatibility.
sgd = optim.Adam(classifier.parameters())
ce = nn.CrossEntropyLoss()


def embed(x):
    """Encode a batch with the frozen encoder and average its patch features.

    Merges dims 1 and 2 of the encoder output and takes the mean over the
    merged dim, yielding one feature vector per example. Assumes the encoder
    output is (batch, rows, cols, 64) — TODO confirm against cpc.Encoder.
    """
    with torch.no_grad():
        z = encoder(x)
        return z.flatten(start_dim=1, end_dim=2).mean(dim=1)


def evaluate(loader):
    """Run the classifier over `loader` and return (num_correct, num_examples)."""
    classifier.eval()
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            predicted = classifier(embed(x)).argmax(dim=1)
            n_total += y.size(0)
            n_correct += (predicted == y).sum().item()
    return n_correct, n_total


def report_accuracy(n_correct, n_total):
    """Print test accuracy as a percentage, guarding against an empty loader."""
    if n_total == 0:
        print('Test Accuracy: n/a (no test examples)')
    else:
        print('Test Accuracy: {}'.format(100 * n_correct / n_total))


for epoch in range(num_epochs):
    print('EPOCH {}:'.format(epoch + 1))
    classifier.train()
    running_loss = 0
    for i, (x, y) in enumerate(trainloader):
        x, y = x.to(device), y.to(device)
        sgd.zero_grad()
        # embed() returns features detached from the encoder, so gradients
        # only flow into the classifier's parameters.
        loss = ce(classifier(embed(x)), y)
        loss.backward()
        sgd.step()
        running_loss += loss.item()

        if i % log_interval == (log_interval - 1):
            print('\titeration {}: loss = {}'.format(i + 1, running_loss / log_interval))
            running_loss = 0

    if epoch % eval_every == eval_every - 1:
        report_accuracy(*evaluate(testloader))

    # Checkpoint the probe after every epoch (overwrites the same file).
    torch.save(classifier.state_dict(), '../saved_models/clf.chkpt')

# Final test-set evaluation; `correct` / `total` stay available as globals.
correct, total = evaluate(testloader)
report_accuracy(correct, total)


EPOCH 1:
	iteration 10: loss = 3.596562075614929
	iteration 20: loss = 2.5035330772399904
	iteration 30: loss = 2.4364827632904054
	iteration 40: loss = 2.3767350196838377
	iteration 50: loss = 2.32626416683197
	iteration 60: loss = 2.318194556236267
	iteration 70: loss = 2.274809455871582
	iteration 80: loss = 2.2608928203582765
	iteration 90: loss = 2.2697301387786863
	iteration 100: loss = 2.2222812175750732
	iteration 110: loss = 2.235800099372864
	iteration 120: loss = 2.212452244758606
	iteration 130: loss = 2.1899992704391478
	iteration 140: loss = 2.159664177894592
	iteration 150: loss = 2.152384614944458
	iteration 160: loss = 2.154388189315796
	iteration 170: loss = 2.1285857439041136
	iteration 180: loss = 2.120832824707031
	iteration 190: loss = 2.1026601076126097
EPOCH 2:
	iteration 10: loss = 2.098582363128662
	iteration 20: loss = 2.0635650157928467
	iteration 30: loss = 2.0545152902603148
	iteration 40: loss = 2.056039834022522
	iteration 50: loss = 2.033578562736511
	i

In [None]:
import os
# Hard-exit the kernel process. os._exit terminates immediately with status 0
# and does not run cleanup handlers or flush stdio buffers, so this cell kills
# the notebook kernel itself. Presumably used to release the GPU held by this
# notebook — TODO confirm; any unsaved kernel state is lost.
os._exit(0)