In [10]:
from google.colab import drive
from google.colab import files
# Mount Google Drive into the Colab filesystem at /content/drive; the rest of the
# notebook reads data from and writes logs/checkpoints to paths under this mount.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
# Debugging aid: print the kernel's current working directory via a shell escape.
!pwd

/content/drive/My Drive/scanntech


In [0]:
# Debugging aid: list the contents of the current working directory (IPython automagic).
ls

[0m[01;34mcheckpoints[0m/  [01;34mdata[0m/  [01;34mengine[0m/  log.txt  [01;34mmodel[0m/  train.ipynb  train.py  [01;34mutils[0m/


In [0]:
# Switch to the project directory using an absolute path. The previous relative
# `cd scanntech/` only succeeded when the kernel happened to be in the parent
# folder, so it broke on a fresh "Restart & Run All" and when re-run.
# os.chdir is idempotent and does not depend on the current directory.
import os

os.chdir('/content/drive/My Drive/scanntech')
os.getcwd()

/content/drive/My Drive/scanntech


In [0]:
import torch
from torch import nn
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from pathlib import Path
import pandas as pd
from PIL import Image
import torchvision
from torch.nn import Module
import matplotlib.pyplot as plt

In [0]:
class MainDataset(Dataset):
    """Paired image/label dataset driven by a tab-separated ``<split>.nyu`` index file.

    The index file ``<dataset_path>/<split>.nyu`` lists one sample per line as an
    image path and a label path separated by a tab. Both paths are resolved
    relative to ``<dataset_path>/nyud`` when a sample is loaded.

    Args:
        dataset_path: Directory holding the index files and the ``nyud`` folder.
        split: Name of the index file without the ``.nyu`` extension.
        transform: Callable applied to both the image and the label. When
            omitted, ``transforms.ToTensor()`` is used.
        normalize: When true, the image (never the label) is additionally passed
            through ``transforms.Normalize`` with fixed per-channel mean/std.
        max_num: If given, keep only a random sample of this many index rows.
    """

    def __init__(self, dataset_path, split='train', transform=None, normalize=False, max_num = None):
        super(MainDataset, self).__init__()
        self.normalize = None

        self.dataset_path = dataset_path
        self.split = split

        path = Path(self.dataset_path)/str(self.split+'.nyu')
        with open(str(path), 'r') as f:
            self.list_set = self.splitting(f.read())

        # Always assign self.transform. Previously a caller-supplied transform was
        # silently dropped, leaving the attribute undefined and making
        # __getitem__ fail with AttributeError.
        if transform is None:
            transform = transforms.Compose([#transforms.CenterCrop(1600),
                                            transforms.ToTensor()])
        self.transform = transform

        if normalize:
            self.normalize = transforms.Compose([transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
        if max_num is not None:
            self.list_set = self.list_set.sample(max_num)

    def __getitem__(self, indx):
        """Load sample ``indx`` and return ``{'image': ..., 'label': ...}``.

        Both files are opened with PIL and passed through ``self.transform``;
        the optional normalisation is applied to the image only.
        """
        sample = self.list_set.iloc[indx]
        root = Path(self.dataset_path) / 'nyud'

        image = Image.open(str(root / sample['image']))
        label = Image.open(str(root / sample['label']))

        image = self.transform(image)
        label = self.transform(label)

        if self.normalize is not None:
            image = self.normalize(image)

        return {'image': image, 'label': label}

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.list_set)

    def splitting(self, set):
        """Parse the text of an index file into a DataFrame.

        Args:
            set: Full text of the index file. (The parameter name shadows the
                builtin; it is kept for backward compatibility.)

        Returns:
            ``pandas.DataFrame`` with string columns ``image`` and ``label``,
            one row per non-blank line, in file order.

        Raises:
            ValueError: If a non-blank line does not contain a tab separator.
        """
        images = []
        labels = []
        for line_no, raw in enumerate(set.split('\n'), start=1):
            line = raw.rstrip('\r')
            # Skip empty and whitespace-only lines (e.g. a trailing newline).
            if not line.strip():
                continue
            paths = line.split('\t')
            if len(paths) < 2:
                raise ValueError(
                    f'line {line_no} of the index is not "<image>\\t<label>": {line!r}')
            images.append(paths[0])
            labels.append(paths[1])

        return pd.DataFrame({'image': images, 'label': labels})

In [0]:
class MainModel(Module):
    """Backbone + pyramid head + single-channel prediction head.

    Args:
        size: Spatial sizes forwarded unchanged to ``FPNHead``.
        same_size_output: Forwarded unchanged to ``FPNHead``.
    """

    def __init__(self, size, same_size_output=False):
        super().__init__()

        # Attribute names double as state_dict prefixes, so they must not change.
        self.backbone = ResNetBackbone()
        self.head = FPNHead(size, same_size_output=same_size_output)
        # Three 1x1 conv layers mapping the head's 256 channels to one channel.
        self.depthhead = ConvBlock(in_channels=256, out_channels=1,
                                   kernel_size=1, num_layer=3)

    def forward(self, image):
        """Run the backbone feature extractor, the head, then the prediction head."""
        pyramid = self.backbone.get_features(image)
        fused = self.head(pyramid)
        return self.depthhead(fused)


class ResNetBackbone(Module):
    """Wrapper around torchvision's ResNet-18 that exposes intermediate feature maps."""

    def __init__(self):
        super(ResNetBackbone, self).__init__()
        # NOTE(review): newer torchvision releases deprecate `pretrained=True` in
        # favour of the `weights=` argument; kept as-is because the installed
        # torchvision version is not visible from this notebook.
        self.net = torchvision.models.resnet18(pretrained=True)

    def get_features(self, x):
        """Return the outputs of ``layer1``..``layer4`` as a tuple ``(c1, c2, c3, c4)``.

        The input is first passed through the network's stem
        (``conv1`` -> ``bn1`` -> ``relu`` -> ``maxpool``); each subsequent layer
        consumes the previous layer's output.
        """
        x = self.net.conv1(x)
        x = self.net.bn1(x)
        x = self.net.relu(x)
        x = self.net.maxpool(x)
        c1 = self.net.layer1(x)
        c2 = self.net.layer2(c1)
        c3 = self.net.layer3(c2)
        c4 = self.net.layer4(c3)

        return c1, c2, c3, c4

    def forward(self, x):
        """Run the complete wrapped network and return its output."""
        # Call the module itself rather than `.forward()` so that any forward
        # hooks registered on the wrapped network are executed.
        return self.net(x)


class FPNHead(Module):
    """Top-down head that fuses four feature maps into one 256-channel map.

    Each input map goes through its own 1x1 ``ConvBlock`` (512/256/128/64 -> 256
    channels). Starting from the last map, the running result is bilinearly
    resized and added to the next projected map.

    Args:
        size: Indexable collection of target sizes for the resize steps. With
            ``same_size_output`` set, entries 3, 2, 1 are the targets of the
            three merge steps and entry 0 is the target of a final resize;
            otherwise entries 2, 1, 0 are the merge targets and no final resize
            is performed.
        same_size_output: Whether to append the final resize to ``size[0]``.
    """

    def __init__(self, size, same_size_output=False):
        super().__init__()
        self.same_size_output = same_size_output
        self.size = size

        # Lateral 1x1 projections, registered in the same order as before.
        self.block4 = ConvBlock(in_channels=512, out_channels=256, kernel_size=1, num_layer=1)
        self.block3 = ConvBlock(in_channels=256, out_channels=256, kernel_size=1, num_layer=1)
        self.block2 = ConvBlock(in_channels=128, out_channels=256, kernel_size=1, num_layer=1)
        self.block1 = ConvBlock(in_channels=64, out_channels=256, kernel_size=1, num_layer=1)

        resize_kwargs = dict(mode='bilinear', align_corners=True)
        # When a final full-size step exists, the merge targets sit one slot
        # further along in `size`.
        shift = 1 if self.same_size_output else 0
        self.upsample4 = nn.Upsample(size=size[2 + shift], **resize_kwargs)
        self.upsample3 = nn.Upsample(size=size[1 + shift], **resize_kwargs)
        self.upsample2 = nn.Upsample(size=size[shift], **resize_kwargs)
        if self.same_size_output:
            self.upsample1 = nn.Upsample(size=size[0], **resize_kwargs)

    def forward(self, features):
        """Fuse ``features[0..3]`` top-down and return the merged map."""
        merged = self.block4(features[3])

        merge_steps = (
            (self.upsample4, self.block3, 2),
            (self.upsample3, self.block2, 1),
            (self.upsample2, self.block1, 0),
        )
        for resize, lateral, level in merge_steps:
            merged = resize(merged) + lateral(features[level])

        if not self.same_size_output:
            return merged
        return self.upsample1(merged)


class ConvBlock(Module):
    """Stack of ``Conv2d -> BatchNorm2d -> ReLU`` layers.

    Args:
        in_channels: Input channels of the first convolution.
        out_channels: Output channels of every convolution; also the input
            channels of every convolution after the first.
        kernel_size: Kernel size shared by all convolutions (no padding is added).
        num_layer: Number of conv/bn/relu triples. At least one triple is
            always created, even for values below 1.
    """

    def __init__(self, in_channels=128, out_channels=256,
                 kernel_size=1, num_layer=1):
        super(ConvBlock, self).__init__()

        self.block = nn.Sequential()

        channels = in_channels
        for i in range(1, max(num_layer, 1) + 1):
            self.block.add_module(f'conv_{i}', nn.Conv2d(in_channels=channels,
                                                         out_channels=out_channels,
                                                         kernel_size=kernel_size))
            self.block.add_module(f'bn_{i}', nn.BatchNorm2d(out_channels))
            # The ReLU must live inside `self.block`. It used to be registered on
            # `self` for layers 2..N, so `forward` silently skipped it.
            # ReLU has no parameters, so state_dict keys are unaffected, but
            # outputs of multi-layer blocks differ from the old behaviour.
            self.block.add_module(f'relu_{i}', nn.ReLU())
            channels = out_channels

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.block(x)

In [0]:
def logging(path, par, str):
  """Write the text ``str`` to the file at ``path`` opened with mode ``par``.

  The name and parameters shadow the stdlib ``logging`` module and the builtin
  ``str``; they are left unchanged so existing calls keep working.
  """
  handle = open(path, par)
  try:
    handle.write(str)
  finally:
    # Close the file whether or not the write succeeded.
    handle.close()

In [18]:
# Build the datasets and loaders consumed by the training cell below. This code
# used to be disabled inside a string literal, which left `trainloader` and
# `valloader` undefined on a fresh kernel (Restart & Run All failed).
DATA_DIR = '/content/drive/My Drive/scanntech/data/'
trainset = MainDataset(DATA_DIR, 'train', normalize=True)
valset = MainDataset(DATA_DIR, 'val', normalize=True)

trainloader = DataLoader(trainset, batch_size=8)
valloader = DataLoader(valset, batch_size=8)

# Probe the backbone with one reference image to discover the spatial size of
# every feature level. The placeholder `size` is fine here because only
# `model.backbone` is executed; the head is never called on this instance.
model = MainModel(size=[0, 0, 0, 0])

input_transform = torchvision.transforms.Compose([transforms.ToTensor()])#torchvision.transforms.CenterCrop(1600),
image = Image.open('/content/drive/My Drive/scanntech/333.jpg')
img = input_transform(image)

# Only shapes are needed, so skip autograd bookkeeping for the probe pass.
with torch.no_grad():
    c1, c2, c3, c4 = model.backbone.get_features(img[None, :, :, :])

# sizes[0] is the input image's (H, W); sizes[1:] are the (H, W) of c1..c4.
sizes = [img.shape[1:3]]
sizes.extend(feature.shape[2:4] for feature in (c1, c2, c3, c4))

print(sizes)

[torch.Size([480, 640]), torch.Size([120, 160]), torch.Size([60, 80]), torch.Size([30, 40]), torch.Size([15, 20])]


In [0]:
# Train the pyramid head and the prediction head on top of a fixed backbone.
# After every epoch: evaluate on the validation loader, append to the log file,
# save a checkpoint and redraw the validation curve.
LOG_PATH = '/content/drive/My Drive/scanntech/log.txt'
CHECKPOINT_DIR = '/content/drive/My Drive/scanntech/checkpoints'

model = MainModel(size=sizes, same_size_output=True)
logging(LOG_PATH, 'w', 'Start:\n')

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# The optimiser below never receives the backbone parameters, so computing their
# gradients is wasted work and memory; switch them off.
# NOTE(review): the backbone's BatchNorm layers still update their running
# statistics while the model is in train mode - confirm that this is intended.
for param in model.backbone.parameters():
  param.requires_grad_(False)

# Note: 10e-3 == 1e-2.
optimizer = torch.optim.Adam([{'params': model.depthhead.parameters(), 'lr':10e-3},
                              {'params': model.head.parameters(), 'lr':10e-3}], lr=10e-3)

model = model.to(device)
epoch_num = 100
loss_function = nn.MSELoss()
val_score = []

for epoch in range(epoch_num):
  # --- training ---
  model.train()
  epoch_loss = 0
  for batch_sample in trainloader:
    # Reset and apply gradients once per batch. Previously zero_grad()/step()
    # ran once per epoch, so the whole epoch produced a single update from
    # the sum of all batch gradients.
    optimizer.zero_grad()
    output = model(batch_sample['image'].to(device))
    # nn.MSELoss expects (input, target).
    loss = loss_function(output, batch_sample['label'].to(device))
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()

  # --- validation ---
  # eval() freezes BatchNorm statistics and no_grad() avoids building autograd
  # graphs, so validation no longer alters the model or wastes memory.
  model.eval()
  val_error = 0
  with torch.no_grad():
    for batch_sample in valloader:
      output = model(batch_sample['image'].to(device))
      val_error += loss_function(output, batch_sample['label'].to(device)).item()
  val_score.append(val_error)

  log = f'Epoch: {epoch}, train loss:{epoch_loss}, val loss:{val_error}\n'
  logging(LOG_PATH, 'a+', log)

  print(f'Epoch: {epoch}, train loss:{epoch_loss}, val loss:{val_error}')
  torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/checkpoint_epoch_{epoch}.pth')

  plt.plot(val_score)
  plt.title('val_error')
  plt.xlabel('epoch')
  plt.ylabel('summed validation MSE')
  plt.show()

