In [1]:
from matplotlib import pyplot as plt

import numpy as np
import torch
import torch.nn as nn
import tensorboard 

from pytorch_lightning import seed_everything, LightningModule, Trainer
import multiprocessing
import torchmetrics

from torch.utils.data import DataLoader, Dataset
from utils import encode_segmap , decode_segmap, get_nClasses

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

import albumentations as A
from albumentations.pytorch import ToTensorV2


import json
from dataset import CityscapesDataset
from segmentationModel import SemanticSegmentationModel


  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Load the run configuration; the context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open("config.json") as config_file:
    parameter = json.load(config_file)

## Training

In [3]:
model = SemanticSegmentationModel(architecture="ConvMixer")

# Checkpointing: monitor 'val_loss', write into ./checkpoints, keep the top 5
# plus the most recent one.
# NOTE(review): this assumes the model logs a metric named 'val_loss' — confirm
# against SemanticSegmentationModel before relying on it.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints',
    filename='file',
    save_last=True,
    every_n_epochs=1,
    save_top_k=5,
)
logger = TensorBoardLogger("tb_logs", name="my_model")
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# The checkpoint callback and the TensorBoard logger must be handed to the
# Trainer; previously they were constructed but never used, so nothing was
# written to ./checkpoints or ./tb_logs.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=parameter["epochs"],
    callbacks=[checkpoint_callback, lr_monitor],
    logger=logger,
)

# No dataloaders are passed here, so the model is presumably expected to
# provide its own (train/val dataloader hooks) — verify in the model class.
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type          | Params
--------------------------------------------
0 | layer     | ConvMixerUnet | 23.4 M
1 | criterion | DiceLoss      | 0     
2 | metrics   | JaccardIndex  | 0     
--------------------------------------------
23.4 M    Trainable params
0         Non-trainable params
23.4 M    Total params
93.633    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## Testing

In [None]:
# `load_from_checkpoint` is a classmethod: it builds and RETURNS a new model
# holding the checkpoint weights and does not modify an existing instance.
# The returned model must therefore be assigned, otherwise evaluation runs on
# randomly initialised weights.
# NOTE(review): training used architecture="ConvMixer"; if the model does not
# save its hyperparameters, pass architecture="ConvMixer" here as well.
model = SemanticSegmentationModel.load_from_checkpoint("checkpoints/best.ckpt")

In [None]:
# Dataset over the 'val' split with 'fine' / 'semantic' targets, using the
# transform attached to the model.
test_class = CityscapesDataset(
    "./data/",
    split="val",
    mode="fine",
    target_type="semantic",
    transforms=model.transform,
)

# Fixed order (no shuffling), 12 samples per batch.
test_loader = DataLoader(dataset=test_class, shuffle=False, batch_size=12)

In [None]:
# Stand-alone augmentation pipeline: resize, horizontal flip, normalise, then
# convert to a tensor.
# NOTE(review): no other cell visible here references `transform`; the test
# dataset is built with `model.transform` instead — confirm this is still needed.
transform = A.Compose(
    transforms=[
        A.Resize(256, 512),
        A.HorizontalFlip(),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ]
)

In [None]:
# Run the test loop on the current `model` with the Trainer created in the
# training section (that cell must have been executed first).
trainer.test(model=model)

In [None]:
# Forward pass on the first batch of `test_loader` only, with the model on the
# GPU in eval mode and gradient tracking disabled.
model = model.cuda().eval()
with torch.no_grad():
    batch = next(iter(test_loader))
    img, seg = batch
    output = model(img.cuda())

print(img.shape, seg.shape, output.shape)

In [None]:
from torchvision import transforms

# Per-channel statistics matching the A.Normalize(...) call in this notebook.
# NOTE(review): the batches come from `model.transform`; this assumes it
# normalises with the same mean/std — confirm.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

# Normalize computes (x - mean) / std. Feeding it mean' = -mean/std and
# std' = 1/std yields x * std + mean, i.e. the inverse of the forward
# normalisation. Deriving both lists from the same constants avoids the
# earlier typo (1/0.255 instead of 1/0.225) on the blue channel.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(NORM_MEAN, NORM_STD)],
    std=[1 / s for s in NORM_STD],
)


In [None]:
# Index of the sample within the batch to visualise.
sample = 11

# De-normalised input image, ground-truth mask and predicted mask for that sample.
invimg = inv_normalize(img[sample])
outputx = output[sample].detach().cpu()
encoded_mask = encode_segmap(seg[sample].clone())
decoded_mask = decode_segmap(encoded_mask.clone())
# Predicted class per pixel = argmax over dimension 0 of the model output.
decoded_ouput = decode_segmap(outputx.argmax(dim=0))

fig, ax = plt.subplots(ncols=3, figsize=(16, 50), facecolor='white')

# (picture, title) for each of the three panels, left to right.
panels = (
    (invimg.permute(1, 2, 0).numpy(), 'Input Image'),  # channels moved last for imshow
    (decoded_mask, 'Ground mask'),
    (decoded_ouput, 'Predicted mask'),
)
for axis, (picture, title) in zip(ax, panels):
    axis.imshow(picture)
    axis.axis('off')
    axis.set_title(title)

plt.savefig('result.png', bbox_inches='tight')