In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
import torchvision.models as models

from skimage import io, transform
import numpy as np
import pandas as pd
import os
from PIL import Image

In [2]:
NUM_EPOCHS = 10
LR = 1e-3
BATCH_SIZE = 4

In [3]:
tfms = transforms.Compose([transforms.Resize(256), transforms.RandomCrop(224), transforms.ToTensor()])

In [4]:
class AliensDataset(Dataset):

    def __init__(self, category_file, root_dir, num, transform=None):
        self.categories = pd.read_csv(os.path.join(root_dir,category_file))["category"]
        self.root_dir = root_dir
        self.transform = transform
        self.num = num

    def __len__(self):
        return len(self.categories)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir+"/training/", str(idx+1) + ".png")
        image = Image.open(img_name)
        if self.transform:
            image = self.transform(image)
        category = self.categories[idx]
        sample = {'image': image, 'category': category}

        return sample

In [19]:
aliens_dataset = AliensDataset("solution.csv", "/home/krypt/myStuff/skillenza/paniithackathon2019/training.5k/training", 5000, tfms)

In [20]:
dataloader = DataLoader(aliens_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [21]:
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

In [22]:
model = models.resnet34(pretrained=True)

In [23]:
model = nn.Sequential(*list(model.children())[:-2],
                     nn.Conv2d(512, 6, 3, 1),
                     nn.AdaptiveAvgPool2d(1), Flatten(),
                     nn.Softmax())

if torch.cuda.is_available():
    model = model.cuda()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

In [24]:
list(model.children())

[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace),
 MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
 Sequential(
   (0): BasicBlock(
     (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
   (1): BasicBlock(
     (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bi

In [25]:
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    accuracy = 0
    count = 0
    for c, data in enumerate(dataloader):
        optimizer.zero_grad()

        images = data["image"]
        targets = torch.zeros((images.size(0), 6))
        for i in range(len(data["category"])):
            targets[i][data["category"][i]-1] = 1

        if torch.cuda.is_available():
            images = images.cuda()
            targets = targets.cuda()
        
        outs = model(images)
        loss = criterion(outs, targets)
        loss.backward()
        optimizer.step()

        
        total_loss += loss.data
        accuracy += torch.sum(torch.argmax(outs, dim=1) == (data["category"].cuda()-1))
        count += data["image"].size(0)

        if c%200 == 0:
            print("Epoch:", epoch, "Iter:", c, "Total loss:", total_loss, "Accuracy:", accuracy.cpu().numpy()/(count+1))
                
    print("Epoch:", epoch, "Total loss:", total_loss)
    print("Epoch:", epoch, "Accuracy:", accuracy/count)    

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>