In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


import sys
sys.path.append('../Utilities/')

from tqdm.notebook import tqdm

import importlib
import data_utils
importlib.reload(data_utils)

## Import MDS from sklearn
from sklearn.manifold import MDS
from sklearn.metrics import accuracy_score, f1_score
mds = MDS(n_components=1, random_state=0, normalized_stress='auto')

In [2]:
class UNet1D(nn.Module):
    def __init__(self, in_channels, out_channels, depth=2, num_layers=2):
        super(UNet1D, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.depth = depth
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.num_start_filters = 32

        self._create_unet(self.in_channels, self.num_start_filters)
        self.bottleneck = nn.Sequential(
            nn.Conv1d(self.num_start_filters * 2 ** (self.depth - 1), 2 * self.num_start_filters * 2 ** (self.depth - 1), kernel_size=1, padding=0),
            nn.ReLU()
        )
        self.logits = nn.Conv1d(self.num_start_filters, self.out_channels, 1, 1)


    def _create_encoder_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=5, padding=2),
            nn.ReLU()
        )

    def _create_decoder_block(self, in_channels, out_channels):
        return nn.ModuleList([nn.ConvTranspose1d(in_channels, in_channels//2, kernel_size=2, stride=2),
            nn.Conv1d(in_channels, out_channels, kernel_size=5, padding=2),
            nn.ReLU()])

    def _create_unet(self, in_channels, out_channels):
        for _ in range(self.depth):
            self.encoder.append(self._create_encoder_block(in_channels, out_channels))
            in_channels, out_channels = out_channels, out_channels*2

        out_channels = in_channels
        in_channels = in_channels * 2
        for _ in range(self.depth):
            self.decoder.append(self._create_decoder_block(in_channels, out_channels))
            in_channels, out_channels = out_channels, out_channels//2

    def forward(self, x):
        encoded = []
        for enc in self.encoder:
            x = enc(x)
            encoded.append(x)
            x = nn.MaxPool1d(kernel_size=2, stride=2)(x)

        x = self.bottleneck(x)  # Bottleneck layer

        for dec in self.decoder:
            ## Adding input with encoder concatenation
            enc_output = encoded.pop()
            x = dec[0](x)
            ## Pad the decoder output to match the encoder output
            diff = enc_output.shape[2] - x.shape[2]
            x = F.pad(x, (diff // 2, diff - diff // 2))
            x = torch.cat((enc_output, x), dim=1)
            x = dec[1](x)
            x = dec[2](x)
        ## Add softmax to logits
        # x = self.softmax(x)

        return self.logits(x)

input_channels = 6 
output_channels = 2
depth = 4
num_layers = 2

model = UNet1D(input_channels, output_channels, depth, num_layers)
print(model)

UNet1D(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv1d(6, 32, kernel_size=(5,), stride=(1,), padding=(2,))
      (1): ReLU()
    )
    (1): Sequential(
      (0): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=(2,))
      (1): ReLU()
    )
    (2): Sequential(
      (0): Conv1d(64, 128, kernel_size=(5,), stride=(1,), padding=(2,))
      (1): ReLU()
    )
    (3): Sequential(
      (0): Conv1d(128, 256, kernel_size=(5,), stride=(1,), padding=(2,))
      (1): ReLU()
    )
  )
  (decoder): ModuleList(
    (0): ModuleList(
      (0): ConvTranspose1d(512, 256, kernel_size=(2,), stride=(2,))
      (1): Conv1d(512, 256, kernel_size=(5,), stride=(1,), padding=(2,))
      (2): ReLU()
    )
    (1): ModuleList(
      (0): ConvTranspose1d(256, 128, kernel_size=(2,), stride=(2,))
      (1): Conv1d(256, 128, kernel_size=(5,), stride=(1,), padding=(2,))
      (2): ReLU()
    )
    (2): ModuleList(
      (0): ConvTranspose1d(128, 64, kernel_size=(2,), stride=(2,))
      (1): C

In [3]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

In [4]:
# Generate synthetic data
num_superpixels = 300
num_features = 6
synthetic_data = np.random.rand(num_superpixels, num_features)
synthetic_data = torch.tensor(synthetic_data, dtype=torch.float32)

#Reshape
synthetic_data = synthetic_data.unsqueeze(0).transpose(1, 2)

# Pass the synthetic data through the U-Net model
with torch.no_grad():
    output = model(synthetic_data)

print("Input shape:", synthetic_data.shape)
print("Output shape:", output.shape)

Input shape: torch.Size([1, 6, 300])
Output shape: torch.Size([1, 2, 300])


### Creating LightningModule 

In [5]:
class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        # BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = dice_loss
        
        return Dice_BCE

In [6]:
class CloudSegmentationModel(pl.LightningModule):
    def __init__(self, depth=3):
        super(CloudSegmentationModel, self).__init__()
        self.unet = UNet1D(in_channels=6, out_channels=1, depth=depth)
        
    def forward(self, x):
        return nn.Sigmoid()(self.unet(x))

    def accuracy_score(self, y_true, y_pred):
        y_true = y_true.cpu().detach().numpy()
        y_pred = y_pred.cpu().detach().numpy()
        y_pred = np.where(y_pred > 0.5, 1, 0)
        return accuracy_score(y_true, y_pred)

    def training_step(self, batch, batch_idx):
        superpixel, label = batch
        output = self(superpixel)
        loss = nn.BCELoss()(output, label)

        # Calculate accuracy
        predicted = torch.round(output)
        correct = (predicted == label).sum().item()
        total = label.size(0) * label.size(1) * label.size(2)
        accuracy = correct / total

        self.log('train_loss', loss, on_step=False, on_epoch=True)
        self.log('train_accuracy', accuracy, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        superpixel, label = batch
        output = self(superpixel)
        loss = nn.BCELoss()(output, label)

        # Calculate accuracy
        predicted = torch.round(output)
        correct = (predicted == label).sum().item()
        total = label.size(0) * label.size(1) * label.size(2)
        accuracy = correct / total

        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_accuracy', accuracy, on_step=False, on_epoch=True)


    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

In [7]:
model = CloudSegmentationModel()
print("Parameters: ",count_parameters(model))

with torch.no_grad():
    output = model(synthetic_data)

print("Input shape:", synthetic_data.shape)
print("Output shape:", output.shape)

Parameters:  386945
Input shape: torch.Size([1, 6, 300])
Output shape: torch.Size([1, 1, 300])


### Creating Dataset

In [28]:
patches,mask = data_utils.get_patch(path_to_folders_images = "../Dataset/Natural_False_Color/", path_to_folders_labels = "../Dataset/Entire_scene_gts/")

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


In [29]:
X = []
y = []
for i,j in tqdm(list(zip(patches,mask))):
    try:
        output = data_utils.convert_image_array_to_slic_with_properties(i,j,n_segments=300) 

        ## Getting the X and y arrays
        X_array = np.array([list(list(i.values())[0]) + list(i.values())[1:] for i in output[1]])
        y_array = np.array(output[2])

        ## Normalizing the X_array columwise
        X_array[:,0] = X_array[:,0]/255
        X_array[:,1] = X_array[:,1]/255
        X_array[:,2] = X_array[:,2]/255
        X_array[:,3] = X_array[:,3]/512
        X_array[:,4] = X_array[:,4]/512
        X_array[:,5] = X_array[:,5]/1000



        ## Pad the X_array with -1 and y_array with 0 upto 300
        X_array = np.pad(X_array,((0,300-X_array.shape[0]),(0,0)),mode='constant',constant_values=-1)
        y_array = np.pad(y_array,(0,300-y_array.shape[0]),mode='constant',constant_values=0)

        ## Ordering
        ordering = mds.fit_transform(X_array[:,3:5]).reshape(-1)
        X_array = X_array[ordering.argsort()]
        y_array = y_array[ordering.argsort()].reshape(-1,1)

        ## Appending
        X.append(X_array)
        y.append(y_array)
        
    except KeyboardInterrupt:
        break
    except:
        pass

  0%|          | 0/1472 [00:00<?, ?it/s]

In [8]:
# np.save('../Dataset/X.npy', np.array(X))
# np.save('../Dataset/Y.npy', np.array(y))

X = np.load('../Dataset/X.npy')
y = np.load('../Dataset/Y.npy')

In [9]:
class CustomDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x_item = torch.tensor(self.X[idx], dtype=torch.float32).T
        y_item = torch.tensor(self.y[idx], dtype=torch.float32).T
        return x_item, y_item

def create_dataloader(X, y, batch_size=32, shuffle=True):
    dataset = CustomDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataloader

In [10]:
## Divide X and Y into train and test sets
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [11]:
train_loader = create_dataloader(X_train,y_train,batch_size=64,shuffle=True)
test_loader = create_dataloader(X_test,y_test,batch_size=64,shuffle=True)

## Training with PL

In [12]:
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

In [24]:
# Initialize the model and trainer
segmentationModel = CloudSegmentationModel(depth=3)

In [25]:
logger = TensorBoardLogger('lightning_logs/', name='sgd_tagger')

In [26]:
checkpoint_callback = ModelCheckpoint(
    dirpath = './/model_checkpt/',
    filename = 'best-checkpoint',
    save_top_k=1,
    verbose = True,
    monitor='val_loss',
    mode='min'
)

In [27]:
trainer = Trainer(
    logger = logger,
    gpus=1 if torch.cuda.is_available() else None,
    max_epochs=100,
    callbacks=[EarlyStopping(monitor='val_loss', patience=20), checkpoint_callback]
)

# Train the model using the trainer
trainer.fit(segmentationModel, train_loader, test_loader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | unet | UNet1D | 386 K 
--------------------------------
386 K     Trainable params
0         Non-trainable params
386 K     Total params
1.548     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Epoch 0: 100%|██████████| 34/34 [00:00<00:00, 34.01it/s, loss=0.618, v_num=4]

Epoch 0, global step 27: 'val_loss' reached 0.58780 (best 0.58780), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 1: 100%|██████████| 34/34 [00:00<00:00, 46.01it/s, loss=0.514, v_num=4]

Epoch 1, global step 54: 'val_loss' reached 0.51053 (best 0.51053), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 2: 100%|██████████| 34/34 [00:00<00:00, 41.64it/s, loss=0.49, v_num=4] 

Epoch 2, global step 81: 'val_loss' was not in top 1


Epoch 3: 100%|██████████| 34/34 [00:00<00:00, 49.78it/s, loss=0.483, v_num=4]

Epoch 3, global step 108: 'val_loss' reached 0.48426 (best 0.48426), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 4: 100%|██████████| 34/34 [00:00<00:00, 47.42it/s, loss=0.466, v_num=4]

Epoch 4, global step 135: 'val_loss' was not in top 1


Epoch 5: 100%|██████████| 34/34 [00:00<00:00, 49.10it/s, loss=0.468, v_num=4]

Epoch 5, global step 162: 'val_loss' reached 0.48265 (best 0.48265), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 6: 100%|██████████| 34/34 [00:00<00:00, 49.57it/s, loss=0.452, v_num=4]

Epoch 6, global step 189: 'val_loss' reached 0.46208 (best 0.46208), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 7: 100%|██████████| 34/34 [00:00<00:00, 49.06it/s, loss=0.426, v_num=4]

Epoch 7, global step 216: 'val_loss' reached 0.42145 (best 0.42145), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 8: 100%|██████████| 34/34 [00:00<00:00, 50.41it/s, loss=0.392, v_num=4]

Epoch 8, global step 243: 'val_loss' reached 0.37788 (best 0.37788), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 9: 100%|██████████| 34/34 [00:00<00:00, 51.13it/s, loss=0.374, v_num=4]

Epoch 9, global step 270: 'val_loss' reached 0.37673 (best 0.37673), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 10: 100%|██████████| 34/34 [00:00<00:00, 47.85it/s, loss=0.359, v_num=4]

Epoch 10, global step 297: 'val_loss' reached 0.36042 (best 0.36042), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 11: 100%|██████████| 34/34 [00:00<00:00, 50.03it/s, loss=0.353, v_num=4]

Epoch 11, global step 324: 'val_loss' reached 0.35486 (best 0.35486), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 12: 100%|██████████| 34/34 [00:00<00:00, 50.44it/s, loss=0.351, v_num=4]

Epoch 12, global step 351: 'val_loss' was not in top 1


Epoch 13: 100%|██████████| 34/34 [00:00<00:00, 48.74it/s, loss=0.357, v_num=4]

Epoch 13, global step 378: 'val_loss' reached 0.35386 (best 0.35386), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 14: 100%|██████████| 34/34 [00:00<00:00, 50.28it/s, loss=0.343, v_num=4]

Epoch 14, global step 405: 'val_loss' reached 0.35183 (best 0.35183), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 15: 100%|██████████| 34/34 [00:00<00:00, 49.28it/s, loss=0.335, v_num=4]

Epoch 15, global step 432: 'val_loss' reached 0.34039 (best 0.34039), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 16: 100%|██████████| 34/34 [00:00<00:00, 48.21it/s, loss=0.354, v_num=4]

Epoch 16, global step 459: 'val_loss' was not in top 1


Epoch 17: 100%|██████████| 34/34 [00:00<00:00, 47.82it/s, loss=0.343, v_num=4]

Epoch 17, global step 486: 'val_loss' was not in top 1


Epoch 18: 100%|██████████| 34/34 [00:00<00:00, 49.35it/s, loss=0.349, v_num=4]

Epoch 18, global step 513: 'val_loss' was not in top 1


Epoch 19: 100%|██████████| 34/34 [00:00<00:00, 48.98it/s, loss=0.334, v_num=4]

Epoch 19, global step 540: 'val_loss' reached 0.33909 (best 0.33909), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 20: 100%|██████████| 34/34 [00:00<00:00, 47.59it/s, loss=0.334, v_num=4]

Epoch 20, global step 567: 'val_loss' was not in top 1


Epoch 21: 100%|██████████| 34/34 [00:00<00:00, 47.67it/s, loss=0.334, v_num=4]

Epoch 21, global step 594: 'val_loss' was not in top 1


Epoch 22: 100%|██████████| 34/34 [00:00<00:00, 48.50it/s, loss=0.326, v_num=4]

Epoch 22, global step 621: 'val_loss' reached 0.33821 (best 0.33821), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 23: 100%|██████████| 34/34 [00:00<00:00, 49.74it/s, loss=0.334, v_num=4]

Epoch 23, global step 648: 'val_loss' was not in top 1


Epoch 24: 100%|██████████| 34/34 [00:00<00:00, 50.75it/s, loss=0.33, v_num=4] 

Epoch 24, global step 675: 'val_loss' was not in top 1


Epoch 25: 100%|██████████| 34/34 [00:00<00:00, 49.67it/s, loss=0.328, v_num=4]

Epoch 25, global step 702: 'val_loss' was not in top 1


Epoch 26: 100%|██████████| 34/34 [00:00<00:00, 49.97it/s, loss=0.333, v_num=4]

Epoch 26, global step 729: 'val_loss' was not in top 1


Epoch 27: 100%|██████████| 34/34 [00:00<00:00, 47.39it/s, loss=0.33, v_num=4] 

Epoch 27, global step 756: 'val_loss' was not in top 1


Epoch 28: 100%|██████████| 34/34 [00:00<00:00, 47.45it/s, loss=0.324, v_num=4]

Epoch 28, global step 783: 'val_loss' was not in top 1


Epoch 29: 100%|██████████| 34/34 [00:00<00:00, 50.24it/s, loss=0.334, v_num=4]

Epoch 29, global step 810: 'val_loss' was not in top 1


Epoch 30: 100%|██████████| 34/34 [00:00<00:00, 48.36it/s, loss=0.329, v_num=4]

Epoch 30, global step 837: 'val_loss' reached 0.33400 (best 0.33400), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 31: 100%|██████████| 34/34 [00:00<00:00, 50.95it/s, loss=0.318, v_num=4]

Epoch 31, global step 864: 'val_loss' was not in top 1


Epoch 32: 100%|██████████| 34/34 [00:00<00:00, 50.45it/s, loss=0.325, v_num=4]

Epoch 32, global step 891: 'val_loss' was not in top 1


Epoch 33: 100%|██████████| 34/34 [00:00<00:00, 48.09it/s, loss=0.328, v_num=4]

Epoch 33, global step 918: 'val_loss' reached 0.33119 (best 0.33119), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 34: 100%|██████████| 34/34 [00:00<00:00, 49.83it/s, loss=0.319, v_num=4]

Epoch 34, global step 945: 'val_loss' was not in top 1


Epoch 35: 100%|██████████| 34/34 [00:00<00:00, 49.17it/s, loss=0.325, v_num=4]

Epoch 35, global step 972: 'val_loss' was not in top 1


Epoch 36: 100%|██████████| 34/34 [00:00<00:00, 50.26it/s, loss=0.321, v_num=4]

Epoch 36, global step 999: 'val_loss' was not in top 1


Epoch 37: 100%|██████████| 34/34 [00:00<00:00, 49.38it/s, loss=0.328, v_num=4]

Epoch 37, global step 1026: 'val_loss' was not in top 1


Epoch 38: 100%|██████████| 34/34 [00:00<00:00, 48.67it/s, loss=0.326, v_num=4]

Epoch 38, global step 1053: 'val_loss' was not in top 1


Epoch 39: 100%|██████████| 34/34 [00:00<00:00, 50.37it/s, loss=0.318, v_num=4]

Epoch 39, global step 1080: 'val_loss' was not in top 1


Epoch 40: 100%|██████████| 34/34 [00:00<00:00, 47.97it/s, loss=0.33, v_num=4] 

Epoch 40, global step 1107: 'val_loss' was not in top 1


Epoch 41: 100%|██████████| 34/34 [00:00<00:00, 49.42it/s, loss=0.32, v_num=4] 

Epoch 41, global step 1134: 'val_loss' was not in top 1


Epoch 42: 100%|██████████| 34/34 [00:00<00:00, 49.68it/s, loss=0.32, v_num=4] 

Epoch 42, global step 1161: 'val_loss' was not in top 1


Epoch 43: 100%|██████████| 34/34 [00:00<00:00, 49.45it/s, loss=0.316, v_num=4]

Epoch 43, global step 1188: 'val_loss' was not in top 1


Epoch 44: 100%|██████████| 34/34 [00:00<00:00, 46.64it/s, loss=0.315, v_num=4]

Epoch 44, global step 1215: 'val_loss' was not in top 1


Epoch 45: 100%|██████████| 34/34 [00:00<00:00, 49.34it/s, loss=0.313, v_num=4]

Epoch 45, global step 1242: 'val_loss' was not in top 1


Epoch 46: 100%|██████████| 34/34 [00:00<00:00, 47.44it/s, loss=0.312, v_num=4]

Epoch 46, global step 1269: 'val_loss' was not in top 1


Epoch 47: 100%|██████████| 34/34 [00:00<00:00, 44.87it/s, loss=0.313, v_num=4]

Epoch 47, global step 1296: 'val_loss' reached 0.33109 (best 0.33109), saving model to 'D:\\Projects\\ComputerVision_CloudSegmentation\\Superpixel-UNET\\model_checkpt\\best-checkpoint-v5.ckpt' as top 1


Epoch 48: 100%|██████████| 34/34 [00:00<00:00, 50.17it/s, loss=0.323, v_num=4]

Epoch 48, global step 1323: 'val_loss' was not in top 1


Epoch 49: 100%|██████████| 34/34 [00:00<00:00, 47.96it/s, loss=0.322, v_num=4]

Epoch 49, global step 1350: 'val_loss' was not in top 1


Epoch 50: 100%|██████████| 34/34 [00:00<00:00, 47.89it/s, loss=0.324, v_num=4]

Epoch 50, global step 1377: 'val_loss' was not in top 1


Epoch 51: 100%|██████████| 34/34 [00:00<00:00, 49.38it/s, loss=0.31, v_num=4] 

Epoch 51, global step 1404: 'val_loss' was not in top 1


Epoch 52: 100%|██████████| 34/34 [00:00<00:00, 50.00it/s, loss=0.319, v_num=4]

Epoch 52, global step 1431: 'val_loss' was not in top 1


Epoch 53: 100%|██████████| 34/34 [00:00<00:00, 49.60it/s, loss=0.317, v_num=4]

Epoch 53, global step 1458: 'val_loss' was not in top 1


Epoch 54: 100%|██████████| 34/34 [00:00<00:00, 46.75it/s, loss=0.319, v_num=4]

Epoch 54, global step 1485: 'val_loss' was not in top 1


Epoch 55: 100%|██████████| 34/34 [00:00<00:00, 48.88it/s, loss=0.312, v_num=4]

Epoch 55, global step 1512: 'val_loss' was not in top 1


Epoch 56: 100%|██████████| 34/34 [00:00<00:00, 50.22it/s, loss=0.315, v_num=4]

Epoch 56, global step 1539: 'val_loss' was not in top 1


Epoch 57: 100%|██████████| 34/34 [00:00<00:00, 48.26it/s, loss=0.311, v_num=4]

Epoch 57, global step 1566: 'val_loss' was not in top 1


Epoch 58: 100%|██████████| 34/34 [00:00<00:00, 50.21it/s, loss=0.315, v_num=4]

Epoch 58, global step 1593: 'val_loss' was not in top 1


Epoch 59: 100%|██████████| 34/34 [00:00<00:00, 50.56it/s, loss=0.308, v_num=4]

Epoch 59, global step 1620: 'val_loss' was not in top 1


Epoch 60: 100%|██████████| 34/34 [00:00<00:00, 50.59it/s, loss=0.309, v_num=4]

Epoch 60, global step 1647: 'val_loss' was not in top 1


Epoch 61: 100%|██████████| 34/34 [00:00<00:00, 47.08it/s, loss=0.308, v_num=4]

Epoch 61, global step 1674: 'val_loss' was not in top 1


Epoch 62: 100%|██████████| 34/34 [00:00<00:00, 50.44it/s, loss=0.307, v_num=4]

Epoch 62, global step 1701: 'val_loss' was not in top 1


Epoch 63: 100%|██████████| 34/34 [00:00<00:00, 49.28it/s, loss=0.315, v_num=4]

Epoch 63, global step 1728: 'val_loss' was not in top 1


Epoch 64: 100%|██████████| 34/34 [00:00<00:00, 48.05it/s, loss=0.318, v_num=4]

Epoch 64, global step 1755: 'val_loss' was not in top 1


Epoch 65: 100%|██████████| 34/34 [00:00<00:00, 51.13it/s, loss=0.313, v_num=4]

Epoch 65, global step 1782: 'val_loss' was not in top 1


Epoch 66: 100%|██████████| 34/34 [00:00<00:00, 50.15it/s, loss=0.309, v_num=4]

Epoch 66, global step 1809: 'val_loss' was not in top 1


Epoch 67: 100%|██████████| 34/34 [00:00<00:00, 47.09it/s, loss=0.308, v_num=4]

Epoch 67, global step 1836: 'val_loss' was not in top 1


Epoch 67: 100%|██████████| 34/34 [00:00<00:00, 46.64it/s, loss=0.308, v_num=4]


## Training the module

In [103]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

segmentationModel = CloudSegmentationModel().to(device)

# train_loader = None # Train loader for our dataset
# test_loader = None # Test loader for our dataset

# loss function and optimizer
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(segmentationModel.parameters(), lr=0.001)

# Training loop
num_epochs = 100
segmentationModel.train() 

for epoch in tqdm(range(num_epochs)):
    running_loss = 0

    segmentationModel.train()
    for superpixel, label in train_loader:
        superpixel = superpixel.to(device)
        label = label.to(device)

        # Forward pass
        output = segmentationModel(superpixel)
        loss = criterion(output, label)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        ## Calculate accuracy
        predicted = torch.round(output)
        correct = (predicted == label).sum().item()
        total = label.size(0) * label.size(1) * label.size(2)
        accuracy = correct / total


    # Print epoch statistics
    epoch_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}')

    # Evaluation
    segmentationModel.eval()
    test_loss = 0

    with torch.no_grad():
        for superpixel, label in test_loader:
            superpixel = superpixel.to(device)
            label = label.to(device)

            output = segmentationModel(superpixel)
            test_loss += criterion(output, label).item()

            ## Calculate accuracy
            predicted = torch.round(output)
            correct = (predicted == label).sum().item()
            total = label.size(0) * label.size(1) * label.size(2)
            accuracy = correct / total


    test_loss /= len(test_loader)
    ## Print loss and accuracy
    print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}')


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/100], Loss: 0.6364, Accuracy: 0.6186
Test Loss: 0.5749, Accuracy: 0.6813
Epoch [2/100], Loss: 0.5424, Accuracy: 0.8075
Test Loss: 0.5319, Accuracy: 0.7332
Epoch [3/100], Loss: 0.5089, Accuracy: 0.7436
Test Loss: 0.5077, Accuracy: 0.7726
Epoch [4/100], Loss: 0.4980, Accuracy: 0.7191
Test Loss: 0.5190, Accuracy: 0.7041
Epoch [5/100], Loss: 0.4874, Accuracy: 0.7807
Test Loss: 0.5052, Accuracy: 0.7741
Epoch [6/100], Loss: 0.4865, Accuracy: 0.8059
Test Loss: 0.4932, Accuracy: 0.7659
Epoch [7/100], Loss: 0.4706, Accuracy: 0.7999
Test Loss: 0.4940, Accuracy: 0.6825
Epoch [8/100], Loss: 0.4623, Accuracy: 0.7862
Test Loss: 0.4570, Accuracy: 0.7324
Epoch [9/100], Loss: 0.4311, Accuracy: 0.7923
Test Loss: 0.4193, Accuracy: 0.8353
Epoch [10/100], Loss: 0.4135, Accuracy: 0.8057
Test Loss: 0.4009, Accuracy: 0.8276
Epoch [11/100], Loss: 0.3908, Accuracy: 0.8448
Test Loss: 0.3932, Accuracy: 0.8185
Epoch [12/100], Loss: 0.3822, Accuracy: 0.8141
Test Loss: 0.4353, Accuracy: 0.8000
Epoch [13/100