### Imports & constants

In [1]:
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision.utils import make_grid

from synthesis_dataset import SynthesisDataset
from net import Cycle
from bdcn import BDCN

In [2]:
SEED = 42

# Seed every RNG this notebook relies on: numpy, and torch (the default torch
# generator drives DataLoader shuffling and random weight initialisation).
np.random.seed(SEED)
torch.manual_seed(SEED)

# cuDNN autotuning picks the fastest conv algorithms for fixed input sizes;
# note it can make results non-bit-reproducible between runs.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Fall back to CPU so the notebook can still be executed without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 2

In [11]:
def combined_loss(preds, targets, side_weight=0.5, fuse_weight=1.1, balance=1.2):
    """Class-balanced BCE-with-logits loss over BDCN side outputs and fused output.

    Pixels with label > 0.5 count as positives. Each sample gets its own
    weights:
      * positive pixels are weighted by ``num_neg / total``;
      * negative pixels are weighted by ``balance * num_pos / total``;
    where ``total = c * h * w``. The rarer class therefore gets the larger weight.

    Args:
        preds: sequence of logit tensors shaped like the labels. Every element
            except the last is a side output (scaled by ``side_weight``). The
            last element is the fused output (scaled by ``fuse_weight``).
        targets: indexable whose element 0 is the label tensor. It must be 4-D
            ``(b, c, h, w)`` for the shape unpacking below.
        side_weight: loss multiplier for each side output (BDCN default 0.5).
        fuse_weight: loss multiplier for the fused output (BDCN default 1.1).
        balance: extra multiplier applied to negative-pixel weights.

    Returns:
        Scalar tensor: sum of all weighted, sum-reduced losses divided by the
        batch size.
    """
    # NOTE(review): if ``targets`` is already the batched label tensor (e.g.
    # ``data['outlines']``), ``targets[0]`` is a single 3-D sample and the
    # unpacking below fails -- confirm what callers pass.
    labels = targets[0].float()
    mask = (targets[0] > 0.5).float()
    batch, channels, height, width = mask.shape
    total = channels * height * width

    num_pos = mask.sum(dim=[1, 2, 3])
    num_neg = total - num_pos
    # Per-sample scalar weights, shaped (b, 1, 1, 1) to broadcast over (c, h, w).
    neg_weight = (num_pos * balance / total).view(batch, 1, 1, 1)
    pos_weight = (num_neg / total).view(batch, 1, 1, 1)
    weight = torch.where(mask != 0, pos_weight, neg_weight)

    losses = [
        side_weight * F.binary_cross_entropy_with_logits(
            pred.float(), labels, weight=weight, reduction='sum')
        for pred in preds[:-1]
    ]
    losses.append(fuse_weight * F.binary_cross_entropy_with_logits(
        preds[-1].float(), labels, weight=weight, reduction='sum'))

    # Normalise by batch size so the loss scale does not depend on it.
    return sum(loss / batch for loss in losses)

### Dataloaders

In [3]:
# Dataset root relative to this notebook. It is built with pathlib rather than a
# backslash string literal: the old literal relied on invalid escape sequences
# ("\s", "\o", "\M") being passed through and only worked on Windows.
# str() keeps the constructor argument a plain string as before.
DATASET_DIR = Path('..') / 'simulation-synthesis' / 'output' / 'MLDataset'
TRAIN_FRACTION = 0.9

dataset = SynthesisDataset(str(DATASET_DIR), extension='.png')
# Restrict samples to the two modalities used below: input 'img', target 'outlines'.
dataset.modalities = ['img', 'outlines']

# Contiguous (non-shuffled) split: first TRAIN_FRACTION for training, rest for validation.
split_idx = int(TRAIN_FRACTION * len(dataset))
# Fail early with a clear message instead of an opaque RandomSampler error.
assert split_idx > 0, f"training split is empty; found {len(dataset)} samples in {DATASET_DIR}"
rtrain = Subset(dataset, range(0, split_idx))
rval = Subset(dataset, range(split_idx, len(dataset)))

train_loader = DataLoader(rtrain, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(rval, batch_size=batch_size, shuffle=False)

In [4]:
# BDCN built with pretrained=True, moved to the configured device, and put
# into training mode.
model = BDCN(pretrained=True)
model = model.to(device)
model.train()

# SGD hyper-parameters: lr 1e-6, momentum 0.9, weight decay 2e-4.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-6,
    momentum=0.9,
    weight_decay=0.0002,
)

  nn.init.constant(param, 0.080)


In [15]:
# Smoke test: push only the first training batch through the model.
# No backward/optimizer step is performed yet.
for data in train_loader:
    optimizer.zero_grad()
    images, labels = (data[key].to(device) for key in ('img', 'outlines'))
    out = model(images)
    # TODO: enable once combined_loss is checked against these label shapes:
    #   loss = combined_loss(out, labels); loss.backward(); batch_loss += loss.item()
    break  # a single batch is enough for this check