In [87]:
import sys
import os

import scipy.misc
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

from tqdm.notebook import tqdm

import torchvision
from torchvision import datasets as torchsets
import matplotlib.pyplot as plt
%matplotlib

Using matplotlib backend: Qt5Agg


In [24]:
# Fail fast on Python 2: the rest of the notebook relies on Python 3 only
# libraries. (The exception class name was previously misspelled, which
# turned this guard into a NameError instead of the intended message.)
if sys.version_info[0] < 3:
    raise Exception('You must use Python 3 or higher.')

In [76]:
class DatasetRepeater(Dataset):
    """Present a map-style dataset as if it were repeated ``num_repeats`` times.

    No data is copied: an index into the repeated view is mapped back onto the
    wrapped dataset with a modulo.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        num_repeats: How many times the wrapped dataset is virtually repeated.

    Raises:
        ValueError: If ``num_repeats`` is negative.
    """

    def __init__(self, dataset, num_repeats=100):
        if num_repeats < 0:
            raise ValueError('num_repeats must be non-negative, got {}'.format(num_repeats))
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __len__(self):
        """Return the length of the repeated view."""
        return self.num_repeats * len(self.dataset)

    def __getitem__(self, idx):
        """Return the item at position ``idx`` of the repeated view.

        Negative indices count from the end. An out-of-range index raises
        ``IndexError`` so that plain iteration (``for item in repeater``),
        which relies on ``IndexError`` to stop, terminates.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError('index out of range for DatasetRepeater of length {}'.format(total))
        # The bounds check above guarantees the wrapped dataset is non-empty
        # here, so the modulo cannot divide by zero.
        return self.dataset[idx % len(self.dataset)]

In [80]:
# Download (if needed) and load MNIST; images are converted to tensors.
to_tensor = torchvision.transforms.ToTensor()
mnist_train_base = torchsets.MNIST(root='./data', train=True, download=True, transform=to_tensor)
mnist_test = torchsets.MNIST(root='./data', train=False, download=True, transform=to_tensor)

# Wrap the training split so one pass over the loader covers it many times
# (DatasetRepeater's default repeat count is used).
mnist_train = DatasetRepeater(mnist_train_base)

print(len(mnist_train))
# `.data` replaces the deprecated `.test_data` alias, which emits a warning.
print(mnist_test.data.shape)

6000000
torch.Size([10000, 28, 28])


In [54]:
class Tnet(nn.Module):
    """Tiny CNN classifier: two 3x3 convolutions, max-pooling, one linear layer.

    ``forward`` returns raw, unnormalised scores with 10 entries per sample.
    The linear layer expects 16 features per sample, which is what a
    single-channel 28x28 input produces after pooling with kernel size 7
    (a 4x4 map).
    """

    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(1, 1, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(1, 1, kernel_size=3, padding=1)
        self.fc = nn.Linear(16, 10)
        # Not applied in forward(); kept so existing code that accesses
        # `softmax` on the module keeps working.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Tensor of shape (batch, 1, H, W); H and W must pool down to
                16 features per sample (e.g. 28x28).

        Returns:
            Tensor of shape (batch, 10) with unnormalised scores.

        Raises:
            RuntimeError: If the pooled feature map does not have 16 features
                per sample.
        """
        x = F.relu(self.conv0(x))
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=7)
        # Flatten per sample instead of `view(-1, 16)`: a wrongly sized input
        # now fails loudly at the linear layer rather than being silently
        # folded into a larger batch dimension.
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

In [81]:
# Both splits are served with identical loader settings: mini-batches of 10,
# shuffled order, and a single background worker process.
_loader_kwargs = dict(batch_size=10, shuffle=True, num_workers=1)
train_loader = DataLoader(mnist_train, **_loader_kwargs)
test_loader = DataLoader(mnist_test, **_loader_kwargs)

In [82]:
# Loss function used for training; it is applied to the network's raw
# output scores together with the integer class labels.
criterion = nn.CrossEntropyLoss()

# Freshly initialised network to be trained below.
tnet = Tnet()

In [89]:
# Train on the GPU(s) when available, otherwise fall back to the CPU so the
# cell also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parameters must live on the primary device before DataParallel wraps the
# model. The isinstance guard keeps the cell idempotent: re-running it does
# not nest DataParallel wrappers.
tnet = tnet.to(device)
if device.type == 'cuda' and not isinstance(tnet, DataParallel):
    tnet = DataParallel(tnet, device_ids=list(range(torch.cuda.device_count())))
tnet.train()

optimizer = torch.optim.Adam(tnet.parameters(), lr=0.001, betas=(0.5, 0.999))

for epoch in range(100):
    # Accumulate plain Python floats; summing the loss tensors themselves
    # would keep tensor objects (and device memory) alive for the whole epoch.
    total_loss = 0.0
    with tqdm(total=len(train_loader)) as tqdm_bar:
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)

            out = tnet(x)
            loss = criterion(out, y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Single host transfer per step; the value is reused below.
            loss_value = loss.item()
            total_loss += loss_value

            # The progress-bar caption is only refreshed once the batch loss
            # drops below 0.1 (behaviour kept from the original cell).
            if loss_value < 0.1:
                tqdm_bar.set_description('loss: {:f}'.format(loss_value))
            tqdm_bar.update(1)
        # Mean batch loss for the epoch.
        print(epoch, total_loss/len(train_loader))


HBox(children=(FloatProgress(value=0.0, max=600000.0), HTML(value='')))




KeyboardInterrupt: 