In [1]:
import os
import json
import wandb
import torch
import numpy as np


from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset

from torchvision import models
from torchvision import transforms

from matplotlib import pyplot as plt
from tqdm.notebook import trange

device = 'cuda' if torch.cuda.is_available else 'cpu'

In [2]:
class LabeledDataset(Dataset):
    def __init__(self, path2data='./data/train/labeled/', transform=None):
        self.transform = transform
        self.path = path2data
        walker = os.walk(path2data)
        _, classes, _ = next(walker)
        classes = sorted(classes)
        self.classes = classes
        self.files = []
        self.class2ind = {classes[i]:i for i in range(len(classes))}
        
        for path, folders, files in walker:
            for file in files:
                self.files.append(f'{path}/{file}')
        
        self.length_dataset = len(self.files)
        
    def __len__(self):
        return self.length_dataset

    def __getitem__(self, idx):
        filename = self.files[idx]
        classname = filename.split('/')[-2]
        class_idx = self.class2ind[classname]
        
        img = torch.tensor(plt.imread(filename), dtype=torch.float32)
        img = img.transpose_(0,2)
        
        if self.transform is not None:
            img = self.transform(img)
        
        return img, class_idx

In [3]:
class UnlabeledDataset:
    def __init__(self, path2data='./data/train/unlabeled/', transform=None):
        self.transform = transform
        self.path = path2data
        self.files = []
        for path, folders, files in os.walk(path2data):
            for file in files:
                self.files.append(f'{path}/{file}')
                
        _, _, self.files = next(walker)
        self.length_dataset = len(self.files)
        
    def __len__(self):
        return self.length_dataset

    def __getitem__(self, idx):
#         filename = self.path+'/'+self.files[idx]
        img = torch.tensor(plt.imread(self.files[idx]), dtype=torch.float32)
        img = img.transpose_(0,2)
        
        if self.transform is not None:
            img = self.transform(img)
        
        return img

In [4]:
normalization = transforms.Normalize([132., 126.4, 105.3], [67.8, 66.4, 70.5])

In [5]:
unldat = UnlabeledDataset('./data/', transform=normalization)
pretrain_loader = DataLoader(unldat, batch_size=10, shuffle=True, drop_last=True)

ValueError: num_samples should be a positive integer value, but got num_samples=0

In [7]:
len(unldat)

0

In [None]:
unldat = UnlabeledDataset('./data/train/unlabeled/', transform=normalization)
pretrain_loader = DataLoader(unldat, batch_size=10, shuffle=True, drop_last=True)

labdat = LabeledDataset('./data/train/labeled/', transform=normalization)
train_loader = DataLoader(labdat, batch_size=10, shuffle=True, drop_last=True)

In [None]:
model = models.resnet18(num_classes=10)

In [None]:
def forward_cut(X, model, level=-2):
    chs = list(model.children())[:level]
    res = X
    for ch in chs:
        X = res
        res = ch(X)
    return res

In [None]:
with open('versions.json', 'r') as f:
    versions = json.load(f)
versions['rotation-pretrain'] += 1
cur_ver = versions['rotation-pretrain']
with open('versions.json', 'w') as f:
    json.dump(versions, f)
    
wandb.init(project='Pretrain HW1', name=f'rotation-pretrain:{cur_ver}')

In [None]:
classifier = nn.Sequential(
    nn.Linear(in_features=512, out_features=4, bias=True),
)

In [None]:
model.to(device)
classifier.to(device)

opt = optim.Adam([*list(model.parameters()), *list(classifier.parameters())])

In [None]:
X = next(iter(pretrain_loader))
X.shape

In [None]:
with torch.no_grad():
    print(forward_cut(X.to(device), model, level=-1).shape)

In [None]:
epochs = 100

In [None]:
for epoch in trange(epochs):
    opt.param_groups[0]['lr'] *= 0.98
    for X in pretrain_loader:
        opt.zero_grad()
        X = X.to(device)
        rots = torch.randint(0, 4, (len(X),), device=device)
        for i in range(len(X)):
            X[i] = torch.rot90(X[i], rots[i], [1,2])
        
        hid = forward_cut(X, model, -1)
        hid.squeeze_(-1)
        hid.squeeze_(-1)
        pred = classifier(hid)
        
        loss = F.cross_entropy(pred, rots)
        
        loss.backward()
        opt.step()
        
        wandb.log({'loss':loss.item()})
    wandb.log({'lr':opt.param_groups[0]['lr']})

In [None]:
model = model.train(True)

In [None]:
#model.to(device)

final_opt = optim.Adam(model.parameters(), lr=5e-3)

In [None]:
for epoch in trange(50):
    final_opt.param_groups[0]['lr'] *= 0.96
    for X, y in train_loader:
        final_opt.zero_grad()
        X = X.to(device)
        y = y.to(device)
        
        pred = model(X)
        
        loss = F.cross_entropy(pred, y)
        
        loss.backward()
        final_opt.step()
        
        wandb.log({'final loss':loss.item()})
    wandb.log({'final lr':final_opt.param_groups[0]['lr']})

In [None]:
model = model.train(False)

In [None]:
infer_path, _, infer_files = next(os.walk('./data/test/'))

ans = []
classes_stat = []
for file in infer_files:
    filename = infer_path+'/'+file
    img = torch.tensor(plt.imread(filename), dtype=torch.float32)
    img = img.transpose_(0,2).unsqueeze(0)

    img = normalization(img)
    with torch.no_grad():
        class_idx = model(img.to(device)).argmax()
    classes_stat.append(class_idx.item())
    ans.append((file, labdat.classes[class_idx.item()]))

In [None]:
plt.hist(classes_stat, bins=10);

In [None]:
with open('./result_rotation-pretrain.csv', 'w') as f:
    print('id,class', file=f)
    for file, classn in ans:
        print(f'{file},{classn}', file=f)

In [None]:
labdat.classes