In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm, trange

# Root directory the dataset location is built from: '/data' is appended to it
# when the MNIST datasets are created further down in the notebook.
base_path = '..'

In [2]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU so
# that every later `.to(device)` call still works on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def _to_three_channels(img):
    """Repeat each slice along dimension 0 (the channel axis of a ToTensor output) three times.

    A named helper is used instead of an anonymous lambda so the step shows up
    by name in the ``Compose`` repr and in tracebacks.
    """
    return torch.repeat_interleave(img, 3, 0)


# Deterministic preprocessing shared by the training and the test pipeline.
_base_steps = [
    transforms.ToTensor(),
    _to_three_channels,
    transforms.Resize((32,32)),
]

# Training pipeline: shared preprocessing followed by random rotation as data augmentation.
transform = transforms.Compose(_base_steps + [
    transforms.RandomRotation(10), # data augmentation
])

# Test pipeline: no random augmentation, so evaluation is deterministic and
# measures accuracy on the unmodified test images.
transform_test = transforms.Compose(_base_steps)

# Presumably "only 4 CPU cores available", hence num_workers = 4 for the training loader.
dataloader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(base_path + '/data', train = True, transform = transform, download = True),
    batch_size = 1024,
    shuffle = True,
    num_workers = 4,
)
dataloader_test = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(base_path + '/data', train = False, transform = transform_test, download = True),
    batch_size = 32,
    shuffle = False,
)

In [4]:
def f(x):
    """Return ``x`` with every slice along dimension 1 repeated three times in a row."""
    return torch.repeat_interleave(x, 3, 1)

In [5]:
# Scratch tensor of standard-normal samples with shape (4, 1, 32, 32).
x = torch.randn((4, 1, 32, 32))

In [6]:
# Start from a pretrained ResNet-50.  Newer torchvision releases select the
# checkpoint through the `weights=` enum API, while older ones only know the
# legacy `pretrained` flag, so use whichever this installation provides.
if hasattr(torchvision.models, 'ResNet50_Weights'):
    model = torchvision.models.resnet50(weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
else:
    model = torchvision.models.resnet50(pretrained = True)

# Replace the final fully connected layer with a fresh 10-output head.  The
# input width is read from the layer being replaced instead of hard-coding it.
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

In [7]:
# Cross-entropy between the model outputs and the integer class labels.
loss_func = nn.CrossEntropyLoss()
# Adam over every parameter of the model, with a learning rate of 1e-3.
optimizer = optim.Adam(params = model.parameters(), lr = 1e-3)

In [8]:
# One pass over the training set followed by one pass over the test set per
# epoch.  Running accuracy (and, while training, the last batch loss) is shown
# in the tqdm progress bar.
epoch_iter = trange(1)
for epoch in epoch_iter:
    # train
    num = 0  # correctly classified samples so far (plain Python int)
    den = 0  # samples seen so far
    data_iter = tqdm(dataloader)
    model.train()
    for data in data_iter:
        x = data[0].to(device)
        labels = data[1].to(device)

        y = model(x)
        loss = loss_func(y, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() turns the count into a Python number, so the running
        # counter does not stay a tensor on the device.
        num += (y.argmax(1) == labels).sum().item()
        den += labels.shape[0]
        acc = num / den

        data_iter.set_postfix(acc = acc, loss = loss.item())
    # test
    num = 0
    den = 0
    data_iter = tqdm(dataloader_test)
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every test batch.
    with torch.no_grad():
        for data in data_iter:
            x = data[0].to(device)
            labels = data[1].to(device)

            y = model(x)
            num += (y.argmax(1) == labels).sum().item()
            den += labels.shape[0]
            acc = num / den

            data_iter.set_postfix(acc = acc)

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/59 [00:00<?, ?it/s][A
  0%|          | 0/59 [00:01<?, ?it/s, acc=0.103, loss=2.4][A
  2%|▏         | 1/59 [00:01<01:27,  1.50s/it, acc=0.103, loss=2.4][A
  2%|▏         | 1/59 [00:01<01:27,  1.50s/it, acc=0.224, loss=2.5][A
  3%|▎         | 2/59 [00:01<00:46,  1.24it/s, acc=0.224, loss=2.5][A
  3%|▎         | 2/59 [00:02<00:46,  1.24it/s, acc=0.359, loss=1.57][A
  5%|▌         | 3/59 [00:02<00:32,  1.72it/s, acc=0.359, loss=1.57][A
  5%|▌         | 3/59 [00:02<00:32,  1.72it/s, acc=0.479, loss=0.645][A
  7%|▋         | 4/59 [00:02<00:26,  2.09it/s, acc=0.479, loss=0.645][A
  7%|▋         | 4/59 [00:02<00:26,  2.09it/s, acc=0.566, loss=0.308][A
  8%|▊         | 5/59 [00:02<00:22,  2.40it/s, acc=0.566, loss=0.308][A
  8%|▊         | 5/59 [00:03<00:22,  2.40it/s, acc=0.626, loss=0.243][A
 10%|█         | 6/59 [00:03<00:20,  2.64it/s, acc=0.626, loss=0.243][A
 10%|█         | 6/59 [00:03<00:20,  2.64it/s, acc=0.673, loss