In [None]:
import os
import numpy as np
import torch

#import torchvision.transforms as transforms
from torch import nn
from torchvision import datasets
from torchvision import transforms

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Dataset root and the base name (no extension) of one sample to inspect.
path = '/scratch/ssd/cciw/dataset_voc/'
file = '1349_2016-07-06_2_GLN_3061_crop'

# Locations of the sample under the label and image sub-directories.
label_path = os.path.join(path, 'SegmentationClassPNG', file)
image_path = os.path.join(path, 'JPEGImages', file)

In [None]:
# Preprocessing pipeline consisting of a single ToTensor step.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# for debugging purposes
# dlpath = '/scratch/gallowaa/'
# trainset = ds.VOCSegmentation(dlpath, image_set='val', download=True, transform=transform)
# Training split read from a local copy in VOC layout (download disabled).
# NOTE(review): the same ToTensor pipeline is applied to the label masks; later
# cells multiply targets by 255, presumably to undo ToTensor's rescaling to
# [0, 1] -- confirm against the mask encoding.
trainset = datasets.VOCSegmentation(
    root='/scratch/ssd/cciw/', year='2012', image_set='train',
    download=False, transform=transform, target_transform=transform
)

In [None]:
# Shuffled mini-batches of 10 samples drawn from trainset.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=10, shuffle=True)

In [None]:
# Expose only GPU index 0 to this process via the CUDA_VISIBLE_DEVICES variable.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
# Number of CUDA devices PyTorch reports as available.
torch.cuda.device_count()

In [None]:
# Record whether CUDA is usable in this session.
# NOTE(review): the cells below call .cuda() unconditionally and do not consult this flag.
use_cuda = torch.cuda.is_available()
print(use_cuda)

In [None]:
# Fetch a single batch for the inspection cells that follow and move it to the
# GPU. Pulling one batch from a fresh iterator avoids walking the whole training
# set (and copying every batch to the device) just to keep one.
inputs, targets = next(iter(trainloader))
inputs, targets = inputs.cuda(), targets.cuda()
print(inputs.shape)
print(targets.shape)

In [None]:
idx = 0
# Show channel 0 of sample idx from the current input batch.
plt.imshow(inputs.detach().cpu().numpy()[idx, 0])
#plt.imshow(targets.detach().cpu().numpy()[idx, 0])

In [None]:
# Distinct values in channel 0 of the mask for batch sample 1.
# NOTE(review): the sample index is hard-coded to 1 rather than idx -- confirm this is intended.
np.unique(targets.detach().cpu().numpy()[1, 0])

In [None]:
%matplotlib inline

In [None]:
#1/255

In [None]:
#from torchvision.models.segmentation import fcn_resnet50

In [None]:
#net = fcn_resnet50(num_classes=1).cuda()

In [None]:
from torchvision.models import segmentation as models

In [None]:
# Alphabetical list of the lowercase, non-dunder callables exposed by the
# segmentation module.
model_names = sorted(
    name
    for name, obj in vars(models).items()
    if name.islower() and not name.startswith("__") and callable(obj)
)
print(model_names)

In [None]:
# Restore a previously saved training state from a checkpoint file on disk.
resume = '/scratch/gallowaa/cciw/logs/checkpoint/fcn_resnet50_bs50_wd5e-04_def_1.ckpt'
print('==> Resuming from checkpoint..')
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(resume)
net = checkpoint['net']
# NOTE(review): despite its name, best_acc is read from the 'loss' entry -- confirm which metric this is.
best_acc = checkpoint['loss']
# Continue counting from the epoch after the one stored in the checkpoint.
start_epoch = checkpoint['epoch'] + 1
# Reinstate the random-number-generator state stored in the checkpoint.
torch.set_rng_state(checkpoint['rng_state'])

In [None]:
# Display the value restored from checkpoint['loss'] (bound to the name best_acc).
best_acc

In [None]:
# Build a fresh FCN-ResNet50 with a single output channel and move it to the GPU.
# The factory is looked up on `models`, the alias under which
# torchvision.models.segmentation is imported in this notebook (the previous
# `segmentation` name was never defined and raised NameError).
# NOTE(review): this rebinds `net`, replacing the model restored from the checkpoint.
net = models.__dict__['fcn_resnet50'](num_classes=1)
net = net.cuda()

In [None]:
# Formats 5e-4 in exponent notation with no decimals ('5e-04'), the same form as
# the 'wd5e-04' tag in the checkpoint filename.
'%.e' % 5e-4

In [None]:
#net.eval()
# Put the network in training mode and print the resulting flag.
net.train()
print(net.training)

In [None]:
# Forward pass on the current batch; the model output is indexed with 'out'
# to obtain the prediction tensor.
pred = net(inputs)['out']
pred.shape

In [None]:
# Shape of the label batch, for comparison with pred.shape.
targets.shape

In [None]:
# Reorder the axes to (0, 2, 3, 1) -- presumably NCHW -> NHWC so imshow can
# draw the images; confirm the input layout.
inputs_nhwc = inputs.permute(0, 2, 3, 1)

In [None]:
# Side-by-side view of one batch sample: input image, channel 0 of the model
# output, and channel 0 of the label mask.
idx = 0
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(8, 3))

# Left panel: the input image in channels-last layout.
ax1.imshow(inputs_nhwc[idx].detach().cpu().numpy())
ax1.set_title('Image')
ax1.axis('off')

# Middle panel: raw model output for the sample.
ax2.imshow(pred[idx, 0].detach().cpu().numpy())
ax2.set_title('Logits')
ax2.axis('off')

# Right panel: the label mask for the sample.
ax3.imshow(targets[idx, 0].detach().cpu().numpy())
ax3.set_title('Labels')
ax3.axis('off')

plt.tight_layout()

In [None]:
#pred.min()

In [None]:
#plt.imshow(pred.detach().cpu().numpy()[idx, 1])

In [None]:
#plt.imshow(torch.argmax(pred, dim=1).detach().cpu().numpy()[idx])

In [None]:
# Re-check the shape of the prediction tensor.
pred.shape

In [None]:
# Largest value in the current prediction tensor.
pred.max()

In [None]:
# ln(2) ~= 0.693 -- presumably a reference value: binary cross-entropy equals
# ln(2) when every logit is 0 (sigmoid output 0.5).
np.log(2)

In [None]:
# Maximum label value after multiplying the targets by 255.
(targets * 255).max()

In [None]:
loss_fn(pred, targets * 255)

In [None]:
loss_fn = nn.BCEWithLogitsLoss()

In [None]:
# Plain SGD over all network parameters: learning rate 0.1, weight decay 5e-4.
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=5e-4)

In [None]:
# Minimal training loop: one full pass over trainloader per epoch.
epochs = 1
for epoch in range(epochs):
    for batch, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        pred = net(inputs)['out']
        # Targets are multiplied by 255 before being passed to the loss.
        # NOTE(review): presumably this maps the masks back to {0, 1} -- confirm.
        loss = loss_fn(pred, targets * 255)
        # Clear stale gradients, backpropagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Reports the loss of the final batch of the epoch only, not an epoch average.
    print('Epoch [{}/{}], Loss: {:.4f}'
          .format(epoch + 1, epochs, loss.item()))

In [None]:
#pixel_acc(targets, targets)
# NOTE(review): pixel_acc is defined in a later cell; on a fresh kernel this call
# raises NameError unless that cell has been run first.
# Here the prediction tensor is passed as both the prediction and the label.
pixel_acc(pred, pred)

In [None]:
# Shape of the most recent label batch.
targets.shape

In [None]:
def pixel_acc(pred, label):
    """Return the fraction of valid pixels whose predicted class equals the label.

    Args:
        pred: Score tensor with the class/channel axis at dim 1, e.g.
            (N, C, H, W). With more than one channel the predicted class is
            the argmax over dim 1. With exactly one channel the values are
            treated as binary logits and thresholded at 0 (sigmoid > 0.5),
            since an argmax over a single channel would always be 0.
        label: Class indices shaped like ``pred`` without the channel axis.
            A label that keeps a singleton channel axis, e.g. (N, 1, H, W),
            is also accepted. Entries below 0 are ignored.

    Returns:
        A 0-dim float tensor with the accuracy over the valid pixels
        (0 when there are no valid pixels).
    """
    if pred.shape[1] == 1:
        # Single-channel (binary) output: threshold the logit instead of argmax.
        preds = (pred[:, 0] > 0).long()
    else:
        preds = torch.argmax(pred, dim=1)

    # Drop a singleton channel axis so label and preds compare element-wise
    # rather than broadcasting across the batch dimension.
    if label.dim() == pred.dim() and label.shape[1] == 1:
        label = label[:, 0]

    valid = label >= 0
    correct = (preds == label) & valid
    # The small epsilon guards against division by zero when nothing is valid.
    acc = correct.sum().float() / (valid.sum().float() + 1e-10)
    return acc

In [None]:
class SegmentationModuleBase(nn.Module):
    """Base module that supplies a pixel-accuracy metric to segmentation models."""

    def __init__(self):
        super().__init__()

    def pixel_acc(self, pred, label):
        """Accuracy over the pixels whose label is non-negative.

        Args:
            pred: Score tensor; the predicted class is the index of the
                maximum along dim 1.
            label: Tensor of class indices compared element-wise with the
                predicted classes; entries below 0 are excluded.

        Returns:
            Float tensor: matching valid pixels divided by the number of
            valid pixels (a 1e-10 term keeps the denominator non-zero).
        """
        predicted_classes = torch.max(pred, dim=1)[1]
        valid_mask = (label >= 0).long()
        hits = (predicted_classes == label).long()
        num_correct = torch.sum(valid_mask * hits)
        num_valid = torch.sum(valid_mask)
        return num_correct.float() / (num_valid.float() + 1e-10)


class SegmentationModule(SegmentationModuleBase):
    """Encoder/decoder segmentation model bundled with its training criterion.

    Args:
        net_enc: Encoder; called with the image batch and
            ``return_feature_maps=True``.
        net_dec: Decoder; called with the encoder output (plus ``segSize``
            at inference).
        crit: Loss callable applied as ``crit(prediction, labels)``.
        deep_sup_scale: When not None, the decoder output is unpacked into a
            main and an auxiliary prediction, and the auxiliary loss is added
            to the main loss scaled by this factor.
    """

    def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None):
        super().__init__()
        self.encoder = net_enc
        self.decoder = net_dec
        self.crit = crit
        self.deep_sup_scale = deep_sup_scale

    def forward(self, feed_dict, *, segSize=None):
        """Run a training step (``segSize`` is None) or inference (otherwise).

        Args:
            feed_dict: Mapping with the input batch under ``'img_data'`` and,
                for training, the labels under ``'seg_label'``.
            segSize: Keyword-only; when given it is forwarded to the decoder
                and only the prediction is returned.

        Returns:
            ``(loss, acc)`` in training mode, or the decoder prediction at
            inference.
        """
        features = self.encoder(feed_dict['img_data'], return_feature_maps=True)

        # Inference: no loss or accuracy, just the decoder output.
        if segSize is not None:
            return self.decoder(features, segSize=segSize)

        # Training: with deep supervision the decoder yields two predictions.
        deep_supervised = self.deep_sup_scale is not None
        decoded = self.decoder(features)
        if deep_supervised:
            pred, pred_deepsup = decoded
        else:
            pred = decoded

        loss = self.crit(pred, feed_dict['seg_label'])
        if deep_supervised:
            aux_loss = self.crit(pred_deepsup, feed_dict['seg_label'])
            loss = loss + aux_loss * self.deep_sup_scale

        acc = self.pixel_acc(pred, feed_dict['seg_label'])
        return loss, acc