In [1]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor, Lambda, Compose, Normalize

from torchvision import models

import time
import tqdm as tqdm
from torch.autograd import Variable

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('using {} device'.format(device))

using cuda device


In [2]:
transform = Compose([ToTensor(),
                    Lambda(lambda x: x.repeat(3, 1, 1))])

batch_size = 256

train_data = datasets.FashionMNIST(
    root="/media/storage/Datasets",
    train=True,
    download=True,
    transform=transform
)

test_data = datasets.FashionMNIST(
    root="/media/storage/Datasets",
    train=False,
    download=True,
    transform=transform
)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
    shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
    shuffle=True, num_workers=0)

In [3]:
class ResNetFeatureExtract(nn.Module):
    def __init__(self, pretrained = True):
        super(ResNetFeatureExtract, self).__init__()
        model_resnet18 = models.resnet18(pretrained=pretrained)
        self.conv1 = model_resnet18.conv1
        self.bn1 = model_resnet18.bn1
        self.relu = model_resnet18.relu
        self.maxpool = model_resnet18.maxpool
        self.layer1 = model_resnet18.layer1
        self.layer2 = model_resnet18.layer2
        self.layer3 = model_resnet18.layer3
        self.layer4 = model_resnet18.layer4
        self.avgpool = model_resnet18.avgpool
        
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        
        return x

class ResClassifier(nn.Module):
    def __init__(self, dropout=0.5):
        super(ResClassifier, self).__init__()
        self.fc = nn.Linear(512, 10)
    
    def forward(self, x):
        out = self.fc(x)
        return out
    
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('Linear') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.01)

def test_accuracy(data_iter, netG, netF):
    """Evaluate testset accuracy of a model."""
    acc_sum,n = 0,0
    for (imgs, labels) in data_iter:
        imgs = imgs.to(device)
        labels = labels.to(device)
        netG.eval()
        netF.eval()
        with torch.no_grad():
            labels = labels.long()
            acc_sum += torch.sum((torch.argmax(netF(netG(imgs)), dim=1) == labels)).float()
            n += labels.shape[0]
    return acc_sum.item()/n

In [4]:
netG = ResNetFeatureExtract(pretrained = True)
netF = ResClassifier()

netG = netG.to(device)
netF = netF.to(device)

opt_g = optim.SGD(netG.parameters(), lr=0.01, weight_decay=0.0005)
opt_f = optim.SGD(netF.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

criterion = nn.CrossEntropyLoss()

for epoch in range(0,10):
    n, start = 0, time.time()
    train_l_sum = torch.tensor([0.0], dtype=torch.float32)
    train_acc_sum = torch.tensor([0.0], dtype=torch.float32)
    for i, (imgs, labels) in tqdm.tqdm(enumerate(iter(train_loader))):
        netG.train()
        netF.train()
        imgs = Variable(imgs)
        labels = Variable(labels)
        
        imgs = imgs.to(device)
        labels = labels.to(device)
        train_l_sum = train_l_sum.to(device)
        train_acc_sum = train_acc_sum.to(device)
        
        opt_g.zero_grad()
        opt_f.zero_grad()
        
        bottleneck = netG(imgs)
        
        labels_hat = netF(bottleneck)
        
        loss = criterion(labels_hat, labels)
        loss.backward()
        opt_g.step()
        opt_f.step()
        
        netG.eval()
        netF.eval()
        labels = labels.long()
        train_l_sum += loss.float()
        train_acc_sum += (torch.sum((torch.argmax(labels_hat, dim=1) == labels))).float()
        n += labels.shape[0]
    test_acc = test_accuracy(iter(test_loader), netG, netF)
    print('epoch %d, perdida %.4f, precisión train %.3f, precisión test  %.3f, tiempo %.1f seg'\
        % (epoch + 1, train_l_sum/n, train_acc_sum/n, test_acc, time.time() - start))

235it [00:09, 23.62it/s]


epoch 1, loss 0.0020, train acc 0.822, test acc 0.875, time 10.9 sec


235it [00:09, 24.42it/s]


epoch 2, loss 0.0011, train acc 0.895, test acc 0.883, time 10.6 sec


235it [00:09, 24.29it/s]


epoch 3, loss 0.0009, train acc 0.915, test acc 0.885, time 10.7 sec


235it [00:09, 24.56it/s]


epoch 4, loss 0.0008, train acc 0.928, test acc 0.893, time 10.6 sec


235it [00:09, 24.47it/s]


epoch 5, loss 0.0006, train acc 0.939, test acc 0.893, time 10.7 sec


235it [00:09, 24.40it/s]


epoch 6, loss 0.0005, train acc 0.948, test acc 0.896, time 10.7 sec


235it [00:09, 24.29it/s]


epoch 7, loss 0.0005, train acc 0.955, test acc 0.895, time 10.7 sec


235it [00:09, 24.13it/s]


epoch 8, loss 0.0004, train acc 0.965, test acc 0.887, time 10.9 sec


235it [00:09, 24.30it/s]


epoch 9, loss 0.0003, train acc 0.969, test acc 0.884, time 10.8 sec


235it [00:09, 24.17it/s]


epoch 10, loss 0.0003, train acc 0.974, test acc 0.896, time 10.8 sec
