In [18]:
from utils.transforms import get_transform
from datasets.H2gigaDataset import H2gigaDataset
import torch
from models.BranchedERFNet import BranchedHyperNet
from criterions.my_loss import SpatialEmbLoss
from utils.utils import AverageMeter

# Root directory of the H2giga dataset, relative to this notebook.
DATA_DIR = '../Data/H2giga/'
# NOTE(review): `dir` shadows the builtin `dir()`; it is kept as an alias only because
# existing cells pass it to H2gigaDataset. Prefer DATA_DIR in new code.
dir = DATA_DIR

In [19]:
# Sample pipeline: random rotations/flips, then conversion to tensors.
# `keys` selects which entries of the sample dict each transform is applied to.
transform = get_transform([
    {
        'name': 'RandomRotationsAndFlips',
        'opts': {
            'keys': ('image', 'hs', 'instance', 'label'),
            'degrees': 90,
        }
    },
    {
        'name': 'ToTensor',
        'opts': {
            'keys': ('image', 'hs', 'instance', 'label'),
            # Float tensors for the two inputs, byte tensors for the instance and label maps.
            'type': (torch.FloatTensor, torch.FloatTensor, torch.ByteTensor, torch.ByteTensor),
        }
    },
])
data = H2gigaDataset(dir, type='val', transform=transform)
sample = data[1]  # indexing dispatches to __getitem__
loss_meter, iou_meter = AverageMeter(), AverageMeter()

# Fall back to CPU so this smoke test also runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Only the hyperspectral input gets an explicit leading batch axis for the forward pass.
hs = sample['hs'].unsqueeze(0).to(device)
instance = sample['instance'].to(device)
label = sample['label'].to(device)

# NOTE(review): 154 is presumably the number of hyperspectral bands in `hs` — confirm
# against the dataset. [4, 5] are the per-branch class counts (9 output channels total).
model = BranchedHyperNet(154, [4, 5])
criterion = SpatialEmbLoss(class_weight=[1, 1, 1, 1, 1])

model.to(device)
criterion.to(device)
output = model(hs)
loss = criterion(output, instance, label, iou=True, iou_meter=iou_meter)
print(output.shape)


Creating branched hypernet with [4, 5] classes
Created spatial emb loss function with: to_center: True, n_sigma: 2
torch.Size([1, 9, 1024, 1024])


In [None]:
# Pull the RGB image, instance map and class-label map out of the single sample and
# report their container type and shape, followed by every key the sample provides.
image, ins, clmap = sample['image'], sample['instance'], sample['label']

for tensor in (image, ins):
    print(type(tensor), tensor.shape)
print(sample.keys())

In [None]:
import matplotlib.pyplot as plt

# Band indices rendered as the R, G and B channels of a false-colour composite.
RGB_BANDS = (30, 50, 80)

hs = sample['hs']

# (bands, H, W) -> (H, W, 3) with only the three selected bands.
rgb = hs[list(RGB_BANDS)].permute(1, 2, 0)

# imshow clips float RGB data to [0, 1]; min-max stretch each band independently so the
# composite is visible regardless of the sensor's value range. The clamp avoids a
# division by zero for a constant band.
flat = rgb.reshape(-1, rgb.shape[-1])
band_min = flat.min(dim=0).values
band_max = flat.max(dim=0).values
rgb = (rgb - band_min) / (band_max - band_min).clamp_min(1e-12)

fig, ax = plt.subplots()
ax.imshow(rgb.numpy())
ax.set_title(f'False-colour composite (bands {RGB_BANDS})')
ax.axis('off')

print(hs.shape)

In [None]:
# Instance-ID map. Channels-last, with a singleton channel dropped so imshow colour-maps it.
# Nearest-neighbour sampling prevents the downscaled display from blending adjacent
# IDs into values that are not real instances.
plt.imshow(ins.permute(1, 2, 0).squeeze(-1), cmap='tab20', interpolation='nearest');


In [None]:
# Class-label map. Channels-last, with a singleton channel dropped so imshow colour-maps it.
# Nearest-neighbour sampling keeps class boundaries crisp instead of interpolating
# between label values when the image is downscaled for display.
plt.imshow(clmap.permute(1, 2, 0).squeeze(-1), cmap='tab10', interpolation='nearest')
print(clmap.shape)

In [None]:
import train_config
import torch
from tqdm import tqdm

from datasets import get_dataset
from models import get_model
from criterions.my_loss import SpatialEmbLoss



args = train_config.get_args()

# Use the GPU only when the config requests it AND one is available. Logical `and`
# short-circuits and accepts any truthy config value, unlike bitwise `&`.
device = torch.device("cuda:0" if args['cuda'] and torch.cuda.is_available() else "cpu")

# Read the dataset name from the config, as the validation cell below does. Fall back to
# the previously hard-coded 'H2giga' when the train config has no 'name' entry.
train_dataset = get_dataset(args['train_dataset'].get('name', 'H2giga'),
                            args['train_dataset']['kwargs'])
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args['train_dataset']['batch_size'],
    shuffle=True,
    drop_last=True,
    num_workers=args['train_dataset']['workers'])

model = get_model(args['model']['name'], args['model']['kwargs'])
# Project-specific output-layer setup, parameterised by the loss's n_sigma.
model.init_output(args['loss_opts']['n_sigma'])
model = torch.nn.DataParallel(model).to(device)

criterion = SpatialEmbLoss(**args['loss_opts'])
criterion = torch.nn.DataParallel(criterion).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=1e-4)


In [None]:


# One training epoch over train_dataloader.
model.train()

progress = tqdm(train_dataloader)
for batch in progress:  # `batch`, not `sample`: keeps the notebook's single-sample dict intact
    im = batch['hs'].to(device)

    # Squeeze only axis 1. A bare .squeeze() also removes the batch axis when a batch
    # holds a single sample. Assumes maps are batched as (B, 1, H, W) — TODO confirm;
    # a (B, H, W) batch is left unchanged.
    instances = batch['instance'].squeeze(1).to(device)
    class_labels = batch['label'].squeeze(1).to(device)

    output = model(im)
    # The criterion is wrapped in DataParallel; reduce whatever it gathers to a scalar.
    loss = criterion(output, instances, class_labels).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Show the latest loss in the progress bar instead of printing a line per batch.
    progress.set_postfix(loss=loss.item())

In [None]:
from utils.utils import AverageMeter
val_dataset = get_dataset(args['val_dataset']['name'], args['val_dataset']['kwargs'])
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args['val_dataset']['batch_size'],
    shuffle=False,
    # Keep the final partial batch so every validation sample is scored.
    drop_last=False,
    # Prefer the validation worker count; fall back to the training one used previously.
    num_workers=args['val_dataset'].get('workers', args['train_dataset']['workers']))

model.eval()
loss_meter, iou_meter = AverageMeter(), AverageMeter()

loss_sum, n_batches = 0.0, 0
with torch.no_grad():
    progress = tqdm(val_dataloader)
    for batch in progress:  # `batch`, not `sample`: keeps the notebook's single-sample dict intact
        im = batch['hs'].to(device)

        # Squeeze only axis 1 so a single-sample (e.g. final partial) batch keeps its
        # batch axis. Assumes (B, 1, H, W) maps — TODO confirm.
        instances = batch['instance'].squeeze(1).to(device)
        class_labels = batch['label'].squeeze(1).to(device)

        output = model(im)
        # iou_meter is handed to the criterion, which is expected to update it.
        loss = criterion(output, instances, class_labels, iou=True, iou_meter=iou_meter).mean()

        loss_sum += loss.item()
        n_batches += 1
        progress.set_postfix(loss=loss.item())

# Mean of the per-batch losses; the max() guards an empty loader.
print(f'mean val loss over {n_batches} batches: {loss_sum / max(n_batches, 1):.4f}')

