In [1]:
import torch
import numpy as np
import segmentation_models_pytorch as smp

import torch.nn as nn
from torchmetrics.functional import accuracy

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from torchlightning_module import torch_lightning_DataModule
from dataset_module import DatasetModule

metadata_path='./metadata/flair-one_TOY_metadata.json'
dataset_path='dataset'
#df_for_split_logic='img_ids.jsonl'

#ds=DatasetModule(metadata_path,dataset_path,train=False)

###
class SemanticSegmentationModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        for key in hparams.keys():
             self.hparams[key]=hparams[key]

        self.loss_fn = nn.CrossEntropyLoss()
        self.save_hyperparameters()
        
        self.model = smp.FPN( #modifiy in-channels and output. 
           encoder_name=self.hparams['encoder_name'], 
           classes=self.hparams['classes'], 
           activation=self.hparams['activation'], 
           encoder_weights=self.hparams['encoder_weights'],
            in_channels=self.hparams['in_channels'],

              )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self.model(images)
        labels = labels.squeeze(1).long() 
        loss = self.loss_fn(outputs, labels)
        
        
        #outputs : output of model > logits /// labels > y ///images > x 
        self.log('train_loss', on_step=False, loss, on_epoch=True)
        return loss
    
    
#     def training_epoch_end(self, outputs):
#         avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
#         self.log("avg_train_loss", avg_loss)
#         wandb.log({"avg_train_loss": avg_loss})


    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self.model(images)
        labels = labels.squeeze(1).long()
        #print(f'{labels.shape} is shape of label in the val step method')
        #print(f'{outputs.shape} is shape of outputs in the val step method')

        loss = self.loss_fn(outputs, labels)
        # Log loss and accuracy
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        
        return {'val_loss': loss}

#     def validation_epoch_end(self, outputs):
#         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
#         logs = {'val_loss': avg_loss}
#         return {'val_loss': avg_loss, 'log': logs}

        # def validation_epoch_end(self, outputs):
        #     avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        #     logs = {'val_loss': avg_loss}
        #     self.log('val_loss', avg_loss)



    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    
# def lr_schedule(step):
#    lr = 0.001
#    if step < 10:
#        return lr
#    elif step < 20:
#        return lr / 2
#    else:
#        return lr / 4

#lr_scheduler = pl.callbacks.LearningRateScheduler(lr_schedule)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min'
)


# Define your hyperparameters
hparams = {
    'encoder_name': 'efficientnet-b0', #resnet50
    'classes': 19,
    'activation': 'softmax',
    'encoder_weights': 'imagenet', #imagenet
    'lr': 0.001,#    'batch_size': 8,
    'num_workers': 8,
    'in_channels':5,

}

# Initialize your model
model = SemanticSegmentationModel(hparams)
data=torch_lightning_DataModule(batch_size=32,num_workers=8)

wandb_logger = WandbLogger(project='phase2_semantic_segmentation_initial_run')

trainer = pl.Trainer(max_epochs=10,# accelerator='gpu',
                     callbacks=[early_stopping],
                     logger=wandb_logger
                    ) #lr_scheduler

trainer.fit(model,data)

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Currently logged in as: [33mmarkalsa[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type             | Params
---------------------------------------------
0 | loss_fn | CrossEntropyLoss | 0     
1 | model   | FPN              | 5.8 M 
---------------------------------------------
5.8 M     Trainable params
0         Non-trainable params
5.8 M     Total params
23.050    Total estimated model params size (MB)


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:02<00:02,  2.86s/it]

  return self.activation(x)


                                                                           

  rank_zero_warn(


Epoch 0:  80%|████████  | 20/25 [02:54<00:43,  8.70s/it, loss=2.09, v_num=mzjq]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/5 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/5 [00:00<?, ?it/s][A
Epoch 0:  84%|████████▍ | 21/25 [03:18<00:37,  9.46s/it, loss=2.09, v_num=mzjq]
Epoch 0:  88%|████████▊ | 22/25 [03:20<00:27,  9.13s/it, loss=2.09, v_num=mzjq]
Epoch 0:  92%|█████████▏| 23/25 [03:23<00:17,  8.84s/it, loss=2.09, v_num=mzjq]
Epoch 0:  96%|█████████▌| 24/25 [03:25<00:08,  8.56s/it, loss=2.09, v_num=mzjq]
Epoch 0: 100%|██████████| 25/25 [03:27<00:00,  8.32s/it, loss=2.09, v_num=mzjq]
Epoch 1:  80%|████████  | 20/25 [03:10<00:47,  9.51s/it, loss=2.03, v_num=mzjq]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/5 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/5 [00:00<?, ?it/s][A
Epoch 1:  84%|████████▍ | 21/25 [03:44<00:42, 10.67s/it, loss=2.03, v_num=mzjq]
Epoch 1:  88%|████████▊ | 22/25 [03:46<00:30, 10.29s/i

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 25/25 [04:07<00:00,  9.90s/it, loss=2.03, v_num=mzjq]
