In [1]:
from src.datasets.data_module import CloudCoverDataModule
from pathlib import Path
from src.models.unet import LightningUNet
from src.training.trainer import train
from src.testing.tester import test

### Dataloader

In [2]:
# Build the data module for the cloud-cover dataset. Training features/labels
# are read from the "public" split and test features/labels from the "private"
# split; all paths are relative to the directory the notebook is run from.
data_module = CloudCoverDataModule(
    train_X_folder_path=Path("../data/final/public/train_features/"),
    train_y_folder_path=Path("../data/final/public/train_labels/"),
    test_X_folder_path=Path("../data/final/private/test_features/"),
    test_y_folder_path=Path("../data/final/private/test_labels/"),
    train_batch_size=4,
    val_batch_size=8,
    test_batch_size=8,
    # presumably the fraction of the training data held out for validation —
    # TODO confirm against CloudCoverDataModule
    val_size=0.2,
    # presumably seeds the train/validation split — TODO confirm
    random_state=42
)

In [3]:
# Run the data module's preparation hook explicitly, before any setup() call.
# NOTE(review): what this does (download, indexing, ...) is defined in
# CloudCoverDataModule and is not visible from this notebook.
data_module.prepare_data()

In [4]:
# Set up the data module for both stages used below: "fit" (consumed by the
# training cell) and "test" (consumed by the evaluation cell).
data_module.setup(stage="fit")
data_module.setup(stage="test")

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import pytorch_lightning as pl
import torch.optim as optim
import torchmetrics

# DeepLab model
class ASPP(nn.Module):
    """Simplified ASPP head: parallel dilated 3x3 convolutions fused by a 1x1 convolution.

    Every branch sees the same input and uses ``padding == dilation``, so each
    branch output keeps the input's spatial size. The branch outputs are
    concatenated along the channel axis and projected back to ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by each branch and by the
            fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.dilations = [1, 6, 12, 18]

        # One 3x3 convolution per dilation rate, created in list order.
        self.aspp_blocks = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate)
            for rate in self.dilations
        ])

        # 1x1 projection over the concatenated branch outputs.
        fused_channels = len(self.dilations) * out_channels
        self.output_conv = nn.Conv2d(fused_channels, out_channels, 1)

    def forward(self, x):
        """Apply every branch to ``x``, concatenate on dim 1 and project."""
        branch_maps = []
        for branch in self.aspp_blocks:
            branch_maps.append(branch(x))
        stacked = torch.cat(branch_maps, dim=1)
        return self.output_conv(stacked)

class DeepLabV3(nn.Module):
    """Segmentation network: ResNet-50 feature extractor, ASPP head and a small classifier.

    Args:
        num_classes: Number of output channels (one logit map per class).
        in_channels: Number of input image channels (must be >= 1). The
            ImageNet-pretrained stem has 3 input channels; its weights are
            copied into the first ``min(in_channels, 3)`` channels and the
            weights of any additional channels start at zero.

    ``forward`` maps ``(N, in_channels, H, W)`` to logits of shape
    ``(N, num_classes, H, W)``.

    Raises:
        ValueError: If ``in_channels`` is smaller than 1.
    """

    def __init__(self, num_classes, in_channels=4):
        super().__init__()
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")

        # ImageNet-pretrained ResNet-50. `IMAGENET1K_V1` is the weight set the
        # deprecated `pretrained=True` flag resolved to.
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Swap the 3-channel stem convolution for one that accepts
        # `in_channels`, keeping every other hyperparameter of the original.
        pretrained_stem = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(in_channels,
                                        pretrained_stem.out_channels,
                                        kernel_size=pretrained_stem.kernel_size,
                                        stride=pretrained_stem.stride,
                                        padding=pretrained_stem.padding,
                                        bias=False)

        # Reuse the pretrained filters for the channels that have a
        # counterpart in the original stem and zero-initialise the rest, so
        # the extra channels start without influence on the pretrained
        # features (they still receive gradients during training).
        with torch.no_grad():
            n_copied = min(in_channels, pretrained_stem.in_channels)
            self.backbone.conv1.weight[:, :n_copied] = pretrained_stem.weight[:, :n_copied]
            self.backbone.conv1.weight[:, n_copied:] = 0.0

        # The ASPP head consumes the output of `layer4`, whose width equals
        # the input width of the backbone's classification layer.
        # NOTE: `backbone.avgpool` / `backbone.fc` are never called in
        # `forward`; they are kept so existing checkpoints (whose state_dict
        # contains their weights) still load.
        backbone_channels = self.backbone.fc.in_features
        self.aspp = ASPP(backbone_channels, 256)

        # Final convolutional layers
        self.conv1 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv2 = nn.Conv2d(256, num_classes, 1)

    def forward(self, x):
        """Return per-pixel class logits with the same spatial size as ``x``."""
        # Remember the input resolution so the logits can be resized to it.
        input_size = x.shape[-2:]

        # Extract features from the backbone
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # Pass features through ASPP
        x = self.aspp(x)

        # Additional convolutional layers
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)

        # The stem plus the four ResNet stages downsample by 32 in total, so a
        # fixed `scale_factor=16` would only restore half the input size.
        # Resizing to the recorded size is exact for any input resolution.
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        return x

class LightningDeeplab(pl.LightningModule):
    """PyTorch Lightning wrapper that trains and evaluates ``DeepLabV3``.

    Args:
        n_channels: Number of input image channels passed to the network.
        n_classes: Number of segmentation classes (network output channels).
        bilinear: Stored and saved as a hyperparameter for interface
            compatibility; this class does not otherwise use it.
        learning_rate: Learning rate of the Adam optimizer.

    Each of the training / validation / test steps logs the cross-entropy
    loss, the Jaccard index and the accuracy under ``<stage>_loss``,
    ``<stage>_jaccard`` and ``<stage>_accuracy``. Because both ``on_step`` and
    ``on_epoch`` are enabled, the logged values appear with ``_step`` and
    ``_epoch`` suffixes.
    """

    def __init__(self, n_channels, n_classes, bilinear=True, learning_rate=1e-3):
        super().__init__()
        self.model = DeepLabV3(n_classes, n_channels)
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.learning_rate = learning_rate

        self.save_hyperparameters()

        # Binary metrics are kept for up to two classes (the original
        # behaviour); with more classes the predicted label indices exceed
        # {0, 1}, so the multiclass variants are required. `num_classes` and
        # `average` are only consumed by the multiclass variants.
        metric_task = 'multiclass' if n_classes > 2 else 'binary'

        # One metric instance per stage so their running states never mix.
        self.train_jaccard = torchmetrics.JaccardIndex(num_classes=n_classes, task=metric_task)
        self.train_accuracy = torchmetrics.Accuracy(num_classes=n_classes, task=metric_task, average='macro')

        self.val_jaccard = torchmetrics.JaccardIndex(num_classes=n_classes, task=metric_task)
        self.val_accuracy = torchmetrics.Accuracy(num_classes=n_classes, task=metric_task, average='macro')

        self.test_jaccard = torchmetrics.JaccardIndex(num_classes=n_classes, task=metric_task)
        self.test_accuracy = torchmetrics.Accuracy(num_classes=n_classes, task=metric_task, average='macro')

    def forward(self, x):
        """Return the network's raw logits for the input batch ``x``."""
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Run one forward pass, log loss and metrics for ``stage``, return the loss.

        Args:
            batch: ``(inputs, target)`` pair; ``target`` holds one class index
                per pixel, without a channel dimension.
            stage: ``'train'``, ``'val'`` or ``'test'``; selects the metric
                objects and the prefix of the logged keys.
        """
        inputs, target = batch
        y_hat = self(inputs)
        # Align the logits with the label resolution. This is skipped when the
        # network already returns maps of the target size.
        if y_hat.shape[2:] != target.shape[1:]:
            y_hat = F.interpolate(y_hat, size=target.size()[1:], mode='bilinear', align_corners=False)
        predicted_labels = torch.argmax(y_hat, dim=1)

        loss = self.compute_loss(y_hat, target)
        self.log(f'{stage}_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        for metric_name in ('jaccard', 'accuracy'):
            metric = getattr(self, f'{stage}_{metric_name}')
            metric(predicted_labels, target)
            # Logging the metric object (not its value) lets Lightning handle
            # the epoch-level compute and reset.
            self.log(f'{stage}_{metric_name}', metric, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        return self._shared_step(batch, 'train')

    def compute_loss(self, y_hat, y):
        """Cross-entropy between the logits ``y_hat`` and the class indices ``y``."""
        return F.cross_entropy(y_hat, y)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the validation loss and metrics for one batch."""
        self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log the test loss and metrics for one batch."""
        self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Optimize all parameters with Adam at ``self.learning_rate``."""
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

### Train

In [6]:
# Report whether PyTorch can see a CUDA device; the recorded output below is
# False, i.e. this run executed on CPU.
print(torch.cuda.is_available()) 

False


In [10]:
# Train a freshly constructed DeepLab model (4 input channels, 2 classes) with
# the project's train() helper and keep its return value for the test cell.
# NOTE(review): patience=5 is larger than max_epochs=2, so if `patience` is an
# early-stopping patience it cannot trigger in this run — confirm against
# src.training.trainer.train.
deeplab = train(
    model=LightningDeeplab(n_channels=4, n_classes=2, bilinear=True, learning_rate=0.001),
    run_name="deeplab",
    model_version=0,
    data_module=data_module,
    max_epochs=2,
    patience=5
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name           | Type               | Params
------------------------------------------------------
0 | model          | DeepLabV3          | 45.3 M
1 | train_jaccard  | BinaryJaccardIndex | 0     
2 | train_accuracy | BinaryAccuracy     | 0     
3 | val_jaccard    | BinaryJaccardIndex | 0     
4 | val_accuracy   | BinaryAccuracy     | 0     
5 | test_jaccard   | BinaryJaccardIndex | 0     
6 | test_accuracy  | BinaryAccuracy     | 0     
------------------------------------------------------
45.3 M    Trainable params
0         Non-trainable params
45.3 M    Total params
181.154   Total estimated model params size (MB)


Sanity Checking: |                                                                               | 0/? [00:00<…

Training: |                                                                                      | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

`Trainer.fit` stopped: `max_epochs=2` reached.


### Test  
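To evaluate a saved checkpoint instead of the model still in memory, uncomment the next cell and replace the checkpoint path with the best checkpoint produced by training.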

In [11]:
#deeplab = LightningDeeplab.load_from_checkpoint(checkpoint_path='./logs/deeplab/version_0/checkpoints/deeplab-epoch=01-val_loss=0.67.ckpt', n_channels=4, n_classes=2)

In [12]:
# Evaluate on the test split with the project's test() helper, using the
# `deeplab` object returned by train() in this session.
# NOTE(review): presumably this is the model state at the end of training,
# not necessarily the best checkpoint — confirm what train() returns.
test(
    model=deeplab,
    run_name="deeplab",
    model_version=0,
    data_module=data_module
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
C:\Users\ultav\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |                                                                                       | 0/? [00:00<…

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   test_accuracy_epoch      0.6116666197776794
   test_jaccard_epoch       0.6116637587547302
     test_loss_epoch         0.676939845085144
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
