In [1]:
import os
import torch
import torch.nn as nn
import torchvision.models
import collections
import math
import torch.nn.functional as F
import imagenet.mobilenet
from torch import optim
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def weights_init(m):
    # Initialize kernel weights with Gaussian distributions
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.ConvTranspose2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()


class DepthwiseConv(nn.Module):
    def __init__(self, in_channels, kernel_size):
        super().__init__()
        
        padding = (kernel_size-1) // 2
        assert 2  * padding == kernel_size-1, f"parameters incorrect. kernel={kernel_size}, padding={padding}"
        
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, 
                      kernel_size=kernel_size, padding=padding, stride=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        
        # add init weights
    
    def forward(self, x):
        return self.net(x)
    

class PointwiseConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
            kernel_size=1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        # add init weights
        
    def forward(self, x):
        return self.net(x)
    
class ConvDecomposed(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        
        self.net = nn.Sequential(
            DepthwiseConv(in_channels=in_channels, kernel_size=kernel_size),
            PointwiseConv(in_channels=in_channels, out_channels=out_channels)
        )
        
        # add init weights
        
    def forward(self, x):
        return self.net(x)
        
        
# use this in mobilenet.py
# replace .view() in mobilent.py

class NNConv5(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, interpolation_scale_factor):
        super().__init__()
        
        self.interpolation_scale_factor = interpolation_scale_factor
        
        self.net = ConvDecomposed(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)
        
    def forward(self, x):
        x = self.net(x)
        return F.interpolate(x, scale_factor=self.interpolation_scale_factor, mode='nearest')
    
    
class Model(nn.Module):
    def __init__(self, pretrained=True, decoder_kernel_size=5, decoder_interpolation_scale_factor=2):
        super().__init__()
        
        mobilenet = imagenet.mobilenet.MobileNet()
        if pretrained:
            pretrained_path = os.path.join('cifar100.pth')
            checkpoint = torch.load(pretrained_path, map_location='cpu')
            state_dict = checkpoint.state_dict()

            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:] # remove `module.`
                new_state_dict[name] = v
            #mobilenet.load_state_dict(new_state_dict)
            mobilenet.load_state_dict(state_dict)
        else:
            mobilenet.apply(weights_init)

        for i in range(14):
            setattr( self, 'conv{}'.format(i), mobilenet.model[i])
            
        self.decode_conv1 = NNConv5(in_channels=1024, out_channels=512, 
                                    kernel_size=decoder_kernel_size,
                                    interpolation_scale_factor=decoder_interpolation_scale_factor)
       
        self.decode_conv2 = NNConv5(in_channels=512, out_channels=256, 
                                    kernel_size=decoder_kernel_size,
                                    interpolation_scale_factor=decoder_interpolation_scale_factor)
        
        self.decode_conv3 = NNConv5(in_channels=256, out_channels=128, 
                                    kernel_size=decoder_kernel_size,
                                    interpolation_scale_factor=decoder_interpolation_scale_factor)
    
        self.decode_conv4 = NNConv5(in_channels=128, out_channels=64, 
                                    kernel_size=decoder_kernel_size,
                                    interpolation_scale_factor=decoder_interpolation_scale_factor)
        
        self.decode_conv5 = NNConv5(in_channels=64, out_channels=32, 
                                    kernel_size=decoder_kernel_size,
                                    interpolation_scale_factor=decoder_interpolation_scale_factor)
        
        self.decode_conv6 = PointwiseConv(in_channels=32, out_channels=1)
        weights_init(self.decode_conv1)
        weights_init(self.decode_conv2)
        weights_init(self.decode_conv3)
        weights_init(self.decode_conv4)
        weights_init(self.decode_conv5)
        weights_init(self.decode_conv6)

    def forward(self, x):
        # skip connections: dec4: enc1
        # dec 3: enc2 or enc3
        # dec 2: enc4 or enc5
        for i in range(14):
            layer = getattr(self, f'conv{i}')
            x = layer(x)
            if i==1:
                x1 = x
            elif i==3:
                x2 = x
            elif i==5:
                x3 = x
        for i in range(1,6):
            layer = getattr(self, f'decode_conv{i}')
            x = layer(x)
            if i==4:
                x = x + x1
            elif i==3:
                x = x + x2
            elif i==2:
                x = x + x3
        x = self.decode_conv6(x)
        return x

In [3]:
from dataloaders.nyu import NYUDataset

In [4]:
def train(dataset, net=None, criterion=None, optimizer=None, batch_size=8, lr=3e-4, epochs=20, device=None):
    log = []
    if device is not None:
        net.to(device)
    
    if optimizer is None:
        optimizer = optim.Adam(net.parameters(), lr=lr)

    trainloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    
    stats_step = (len(dataset) // 10 // batch_size) + 1
    for epoch in range(epochs):

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, targets = data
            if device is not None:
                inputs = inputs.to(device)
                targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            losses = criterion(outputs, targets).mean(axis=0)
            loss_value = losses.sum()
            if torch.isnan(loss_value).any():
                warnings.warn("nan loss! skip update")
                print(f"last loss: {[l.item() for l in losses]}")
                break
            running_loss += loss_value
            if (i % stats_step == 0):
                print(f"epoch {epoch}|{i}; total loss:{running_loss / stats_step}")
                print(f"last losses: {[l.item() for l in losses.flatten()]}")
                log.append([l.item() for l in losses.flatten()])
                running_loss = 0.0
            loss_value.backward()
            optimizer.step()
    print('Finished Training')
    return net, log

# CIFAR pretrain

In [None]:
from torchvision import transforms

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Resize(228)])

imagenet_data = torchvision.datasets.CIFAR100('/DATA/vashchilko/', train=True, transform=transform)

In [None]:
net = imagenet.mobilenet.MobileNet()
lr = 0.001
criterion = nn.CrossEntropyLoss()
batch_size = 64
epochs = 50
device = torch.device('cuda:1')

In [None]:
trained_net, log = train(imagenet_data, net=net, criterion=criterion,
                    batch_size=batch_size, lr=lr,
                    epochs=epochs, device=device)

In [None]:
torch.save(trained_net, 'cifar100.pth')

# Training on NYUv2

In [5]:
base_nyu = Path("/DATA/vashchilko/nyudepthv2")
traindir_nyu = base_nyu / 'train'
valdir_nyu = base_nyu / 'val'

assert traindir_nyu.exists(), "Set your own path to train"
assert valdir_nyu.exists(), "Set your own path to val"

In [6]:
train_dataset = NYUDataset(traindir_nyu, split='train')

In [7]:
net = Model(pretrained=True)
lr = 0.01
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0001)
criterion = nn.MSELoss()
batch_size = 8
epochs = 12
device = torch.device('cuda:1')

In [8]:
%%time
trained_net, log = train(train_dataset, net=net, criterion=criterion,
                    batch_size=batch_size, lr=lr,
                    epochs=epochs, device=device)

epoch 0|0; total loss:0.007855179719626904
last losses: [4.6188459396362305]
epoch 0|588; total loss:1.5750302076339722
last losses: [0.8299457430839539]
epoch 0|1176; total loss:1.1929091215133667
last losses: [0.6045311689376831]
epoch 0|1764; total loss:1.1383962631225586
last losses: [0.6842315196990967]
epoch 0|2352; total loss:1.1238526105880737
last losses: [0.999713659286499]
epoch 0|2940; total loss:1.1364283561706543
last losses: [1.2110828161239624]
epoch 0|3528; total loss:1.1381560564041138
last losses: [0.9315807223320007]
epoch 0|4116; total loss:1.082692265510559
last losses: [1.0542418956756592]
epoch 0|4704; total loss:1.0908362865447998
last losses: [1.9334301948547363]
epoch 0|5292; total loss:1.0529404878616333
last losses: [0.3099454641342163]
epoch 1|0; total loss:0.002685386687517166
last losses: [1.5790073871612549]
epoch 1|588; total loss:1.037062168121338
last losses: [1.1000579595565796]
epoch 1|1176; total loss:1.047691822052002
last losses: [0.675275921821

epoch 10|3528; total loss:0.7228920459747314
last losses: [0.5855341553688049]
epoch 10|4116; total loss:0.7226430773735046
last losses: [0.3254247307777405]
epoch 10|4704; total loss:0.7037139534950256
last losses: [0.6441723704338074]
epoch 10|5292; total loss:0.688650906085968
last losses: [0.648614227771759]
epoch 11|0; total loss:0.0006703315884806216
last losses: [0.39415499567985535]
epoch 11|588; total loss:0.6803773641586304
last losses: [1.0145773887634277]
epoch 11|1176; total loss:0.686488926410675
last losses: [0.7559367418289185]
epoch 11|1764; total loss:0.6973398923873901
last losses: [0.3075387179851532]
epoch 11|2352; total loss:0.6868622303009033
last losses: [0.27569329738616943]
epoch 11|2940; total loss:0.7018951177597046
last losses: [0.48289287090301514]
epoch 11|3528; total loss:0.6761839985847473
last losses: [0.33826375007629395]
epoch 11|4116; total loss:0.6826497316360474
last losses: [0.3034030795097351]
epoch 11|4704; total loss:0.6746731400489807
last lo

In [9]:
torch.save(trained_net, 'trained1.pth')

In [10]:
torch.save(trained_net.state_dict(), 'trained1_state.pth')