In [1]:
import torch
import numpy as np
import torchvision
from torchvision.models import vgg16
from torchvision.transforms import ToTensor,ToPILImage, Resize,Compose, Normalize
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.optim import SGD,Adam
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from PIL import Image
import os
import matplotlib.pyplot as plt
#https://github.com/sairin1202/fcn32-pytorch/blob/master/pytorch-fcn32.py

In [2]:
class Relabel:
    def __init__(self,olabel,nlabel):
        self.olabel=olabel
        self.nlabel=nlabel

    def __call__(self,tensor):
        assert isinstance(tensor,torch.LongTensor) 
        tensor[tensor==self.olabel]=self.nlabel
        return tensor


class ToLabel:
    def __call__(self,image):
        return torch.from_numpy(np.array(image)).long().unsqueeze(0)

In [3]:
def load_image(file):
    return Image.open(file)

def get_images(filename):
    image_names=np.loadtxt(filename,dtype=np.str)[0:5]
    return image_names,len(image_names)

input_transform=Compose([
                        Resize((512,512)),
                        ToTensor(),
                        Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
])

label_transform=Compose([
                        Resize((512,512)),
                        ToLabel(),
                        Relabel(256,0)
])

class gdi(Dataset):
    
    def __init__(self,data_dir,label_dir,train_file,input_transform=None,label_transform=None):
        self.data_dir=data_dir
        self.label_dir=label_dir
        self.input_transform=input_transform
        self.label_transform=label_transform
        self.train_file=train_file
        self.train_image_names,self.train_length=get_images(train_file)

    def __getitem__(self,index):

        filename=self.train_image_names[index]
        with open(self.data_dir+str(filename)+'.jpg','rb') as f:
            image=load_image(f).convert('RGB')
        with open(self.label_dir+str(filename)+'.png','rb') as f:
            label=load_image(f).convert('L')

        if self.input_transform is not None:
            image=self.input_transform(image)
        if self.label_transform is not None:
            label=self.label_transform(label)
        
        return image,label
    
    def __len__(self):
        return len(self.train_image_names)

In [4]:
class fcn32(nn.Module):
    def __init__(self):
        super(fcn32,self).__init__()
        self.pretrained_model=vgg16(pretrained=True)
        features,classifiers=list(self.pretrained_model.features.children()),list(self.pretrained_model.classifier.children())

        features[0].padding=(100,100)
        self.features_map=nn.Sequential(*features)
        self.conv=nn.Sequential(nn.Conv2d(512,4096,7),
                                nn.ReLU(inplace=True),
                                nn.Dropout(),
                                nn.Conv2d(4096,4096,1),
                                nn.ReLU(inplace=True),
                                nn.Dropout()
                                )
        self.score_fr=nn.Conv2d(4096,256,1) 
        self.upscore=nn.ConvTranspose2d(256,256,64,32)

    def forward(self,x):
        x_size=x.size() #[1, 3, 512, 512]
        pool=self.conv(self.features_map(x)) #[1, 4096, 12, 12]
        score_fr=self.score_fr(pool) #[1, 256, 12, 12]
        upscore=self.upscore(score_fr) #[1,256,12,12]
        return upscore[:,:,16:(16+x_size[2]),16:(16+x_size[3])]

In [5]:
fcn=fcn32()
fcn=fcn.cuda()


data_dir="/kaggle/input/graphic-design-importance/gd_train/gd_train/"
train_txt_path="/kaggle/input/train-file/train.txt"
label_dir="/kaggle/input/graphic-design-importance/gd_imp_train/gd_imp_train/"

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))




In [6]:
def train(model,batch_size,epoches):
    model.train()

    #weight=torch.ones(21)
    #weight[0]=0

    loader=DataLoader(gdi(data_dir,label_dir,train_txt_path,input_transform,label_transform),
                      num_workers=4,batch_size=batch_size,shuffle=True)
    
    criterion=nn.CrossEntropyLoss()
    optimizer=Adam(model.parameters(),1e-5)

    for epoch in range(1,epoches+1):
        for step,(images,labels) in enumerate(loader):
            images=images.cuda()
            labels=labels.cuda()
            print('label:',labels.size())

            optimizer.zero_grad()
            inputs=Variable(images)
            targets=Variable(labels)
            targets = targets.squeeze(1)
            outputs=model(inputs)
            
            
            # inputs need to be of size [batch, n_class, dim1, dim2]
            # targets need to be of size [batch, dim1, dim2], in which values in dims are between [0,n_class)
            
            loss=criterion(outputs,targets)
            loss.backward()
            optimizer.step()
            if step%1==0:
                print(loss)


x size torch.Size([1, 3, 512, 512])

pool size torch.Size([1, 4096, 16, 16])

score size: torch.Size([1, 256, 16, 16])

upscore size torch.Size([1, 256, 544, 544])

what to return: torch.Size([1, 256, 512, 512])

output: torch.Size([1, 256, 512, 512])

target: torch.Size([1, 1, 512, 512])

In [7]:
train(fcn,3,40)

label: torch.Size([3, 1, 512, 512])
tensor(5.5452, device='cuda:0', grad_fn=<NllLoss2DBackward>)
label: torch.Size([2, 1, 512, 512])
tensor(5.5449, device='cuda:0', grad_fn=<NllLoss2DBackward>)
label: torch.Size([3, 1, 512, 512])
tensor(5.5443, device='cuda:0', grad_fn=<NllLoss2DBackward>)
label: torch.Size([2, 1, 512, 512])
tensor(5.5431, device='cuda:0', grad_fn=<NllLoss2DBackward>)
label: torch.Size([3, 1, 512, 512])
tensor(5.5419, device='cuda:0', grad_fn=<NllLoss2DBackward>)
label: torch.Size([2, 1, 512, 512])
tensor(5.5397, device='cuda:0', grad_fn=<NllLoss2DBackward>)
label: torch.Size([3, 1, 512, 512])
tensor(5.5371, device='cuda:0', grad_fn=<NllLoss2DBackward>)
label: torch.Size([2, 1, 512, 512])
tensor(5.5320, device='cuda:0', grad_fn=<NllLoss2DBackward>)
label: torch.Size([3, 1, 512, 512])
tensor(5.5262, device='cuda:0', grad_fn=<NllLoss2DBackward>)
label: torch.Size([2, 1, 512, 512])
tensor(5.5189, device='cuda:0', grad_fn=<NllLoss2DBackward>)
label: torch.Size([3, 1, 512, 