In [1]:
import os
import numpy as np
import time
import json
import shutil

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
# from src import data_loader

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import sys

# Add the parent directory of the 'playground' folder to the Python path so
# that project-local packages (e.g. `utils`) can be imported from this notebook.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Guard the insert: notebook cells get re-run, and an unconditional insert
# would push a duplicate entry onto sys.path on every execution.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import SegformerForSemanticSegmentation
from utils.data_loader import get_train_data_loaders
import wandb
from tqdm import tqdm


In [12]:
class SegformerTrainer:
    """Fine-tune a SegFormer semantic-segmentation model with W&B logging.

    Everything is driven by a single ``config`` dict. The constructor starts a
    Weights & Biases run and builds the data loaders, the model, an AdamW
    optimizer, a cosine-annealing learning-rate schedule and a class-weighted
    cross-entropy criterion. ``train()`` then runs the train/validate loop and
    writes checkpoints.

    Config keys read by this class: ``data_path``, ``val_split``,
    ``batch_size``, ``model_name``, ``num_classes``, ``class_names``,
    ``learning_rate``, ``weight_decay``, ``epochs``, ``min_lr``, ``save_dir``
    and ``save_every``.
    """

    def __init__(self, config):
        """Create all training components from ``config``.

        Args:
            config (dict): Training settings (see the class docstring).
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize wandb
        wandb.init(
            project="disaster-segmentation",
            config=config
        )

        # Setup data loaders
        self.train_loader, self.val_loader = get_train_data_loaders(
            root_dir=config['data_path'],
            validation_split=config['val_split'],
            batch_size=config['batch_size']
        )

        # Initialize model. Label ids are plain ints in both directions so the
        # two mappings are true inverses of each other.
        class_names = list(config['class_names'])
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            config['model_name'],
            num_labels=config['num_classes'],
            id2label={idx: label for idx, label in enumerate(class_names)},
            label2id={label: idx for idx, label in enumerate(class_names)}
        )
        self.model = self.model.to(self.device)

        # Setup optimizer and scheduler
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config['epochs'],
            eta_min=config['min_lr']
        )

        # Setup loss function with class weights. The ignore index is taken
        # from the model config (falling back to 255) so that pixels the
        # model's built-in loss would skip are skipped here as well.
        class_weights = self._calculate_class_weights()
        self.criterion = nn.CrossEntropyLoss(
            weight=class_weights.to(self.device),
            ignore_index=getattr(self.model.config, 'semantic_loss_ignore_index', 255)
        )

    def _calculate_class_weights(self):
        """Calculate class weights based on the class proportions from EDA.

        Weights are the inverse class frequencies, normalised to sum to 1, and
        are ordered like ``config['class_names']`` so that weight ``i`` always
        belongs to label id ``i``.

        Returns:
            torch.Tensor: 1-D float tensor with one weight per class.

        Raises:
            ValueError: If a configured class name has no known proportion.
        """
        class_proportions = {
            'background': 0.6242538624015307,
            'avalanche': 0.013283235435971551,
            'building_undamaged': 0.05180814333924117,
            'building_damaged': 0.03242294281633547,
            'cracks/fissure/subsidence': 0.0489093545134691,
            'debris/mud//rock flow': 0.06460793595122544,
            'fire/flare': 0.007107057487345816,
            'flood/water/river/sea': 0.0633561099778193,
            'ice_jam_flow': 0.01517923711913106,
            'lava_flow': 0.004472839026679955,
            'person': 0.000470474347778679,
            'pyroclastic_flow': 0.014284453359336742,
            'road/railway/bridge': 0.0537863123311819,
            'vehicle': 0.006058041892953064
        }

        # Follow the configured label order instead of relying on the dict's
        # insertion order; fall back to the dict order if no names are given.
        class_names = self.config.get('class_names') or list(class_proportions)
        unknown = [name for name in class_names if name not in class_proportions]
        if unknown:
            raise ValueError(f"No class proportion available for: {unknown}")

        # Convert proportions to weights (inverse frequency)
        weights = torch.tensor([
            1.0 / (class_proportions[name] + 1e-6)  # epsilon avoids division by zero
            for name in class_names
        ])

        # Normalize weights
        return weights / weights.sum()

    def _compute_loss(self, logits, masks):
        """Apply the class-weighted criterion to the model logits.

        The logits are bilinearly resized to the spatial size of ``masks``
        before the loss is taken (SegFormer predicts at a reduced resolution;
        this mirrors the resize done by the model's built-in loss).

        Args:
            logits (torch.Tensor): Raw class scores, shape ``(N, C, h, w)``.
            masks (torch.Tensor): Integer label maps, shape ``(N, H, W)``.

        Returns:
            torch.Tensor: Scalar loss.
        """
        resized_logits = nn.functional.interpolate(
            logits,
            size=masks.shape[-2:],
            mode='bilinear',
            align_corners=False
        )
        return self.criterion(resized_logits, masks)

    def train_epoch(self):
        """Run one optimisation pass over the training loader.

        Returns:
            float: Mean batch loss over the epoch.
        """
        self.model.train()
        total_loss = 0

        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, (images, masks) in enumerate(pbar):
            images = images.to(self.device)
            masks = masks.to(self.device)

            self.optimizer.zero_grad()

            # Labels are not passed to the model: its internal loss is
            # unweighted, so the loss comes from the weighted criterion.
            outputs = self.model(pixel_values=images)
            loss = self._compute_loss(outputs.logits, masks)

            loss.backward()
            self.optimizer.step()

            # Read the scalar once; each .item() call forces a device sync.
            batch_loss = loss.item()
            total_loss += batch_loss
            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

            wandb.log({
                'batch_loss': batch_loss,
                'learning_rate': self.optimizer.param_groups[0]['lr']
            })

        return total_loss / len(self.train_loader)

    def validate(self):
        """Evaluate the model on the validation loader without gradients.

        Returns:
            float: Mean batch loss over the validation set.
        """
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for images, masks in tqdm(self.val_loader, desc='Validation'):
                images = images.to(self.device)
                masks = masks.to(self.device)

                outputs = self.model(pixel_values=images)
                loss = self._compute_loss(outputs.logits, masks)

                total_loss += loss.item()

        return total_loss / len(self.val_loader)

    def train(self):
        """Run the full training loop for ``config['epochs']`` epochs.

        After every epoch the scheduler is stepped and the losses are logged
        to W&B and printed. The model is saved as ``best_model.pth`` whenever
        the validation loss improves, and as ``checkpoint_epoch_<n>.pth``
        every ``config['save_every']`` epochs.
        """
        best_val_loss = float('inf')

        for epoch in range(self.config['epochs']):
            print(f"\nEpoch {epoch+1}/{self.config['epochs']}")

            train_loss = self.train_epoch()
            val_loss = self.validate()

            self.scheduler.step()

            # Log metrics
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss
            })

            print(f'Train Loss: {train_loss:.4f}')
            print(f'Val Loss: {val_loss:.4f}')

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_model('best_model.pth')

            # Regular checkpoint
            if (epoch + 1) % self.config['save_every'] == 0:
                self.save_model(f'checkpoint_epoch_{epoch+1}.pth')

    def save_model(self, filename):
        """Write model, optimizer and scheduler state plus the config to disk.

        Args:
            filename (str): File name created inside ``config['save_dir']``;
                the directory is created if it does not exist.
        """
        save_path = os.path.join(self.config['save_dir'], filename)
        os.makedirs(self.config['save_dir'], exist_ok=True)

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config
        }, save_path)

        print(f'Model saved to {save_path}')

In [13]:
if __name__ == '__main__':
    # Dataset root. It can be overridden through the LPCVC_DATA_ROOT
    # environment variable so the notebook is not tied to one machine; the
    # fallback is the original local path.
    data_root = os.environ.get(
        'LPCVC_DATA_ROOT',
        'D:/alib/LPCVC2023/data/LPCVC_Train_Updated/LPCVC_Train_Updated/LPCVC_Train_Updated'
    )

    # Label names; the list position is the class id.
    class_names = [
        'background', 'avalanche', 'building_undamaged', 'building_damaged',
        'cracks/fissure/subsidence', 'debris/mud//rock flow', 'fire/flare',
        'flood/water/river/sea', 'ice_jam_flow', 'lava_flow', 'person',
        'pyroclastic_flow', 'road/railway/bridge', 'vehicle'
    ]

    config = {
        'data_path': data_root,
        'model_name': 'nvidia/mit-b0',  # or 'nvidia/mit-b1', 'nvidia/mit-b2', etc.
        # Derived from the label list so the two can never disagree.
        'num_classes': len(class_names),
        'class_names': class_names,
        'batch_size': 8,
        'epochs': 50,
        'learning_rate': 1e-4,
        'min_lr': 1e-6,
        'weight_decay': 0.01,
        'val_split': 0.2,
        'save_dir': 'checkpoints',
        'save_every': 5
    }

    trainer = SegformerTrainer(config)
    trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

In [3]:
class SegformerTrainer:
    """Fine-tune a SegFormer semantic-segmentation model (no experiment tracker).

    Everything is driven by a single ``config`` dict. The constructor builds
    the data loaders, the model, an AdamW optimizer, a cosine-annealing
    learning-rate schedule and a cross-entropy criterion whose class weights
    come from ``_calculate_class_weights``. ``train()`` then runs the
    train/validate loop and writes checkpoints.

    Config keys read by this class: ``img_path``, ``gt_path``, ``val_split``,
    ``batch_size``, ``model_name``, ``num_classes``, ``class_names``,
    ``learning_rate``, ``weight_decay``, ``epochs``, ``min_lr``, ``save_dir``
    and ``save_every``.

    NOTE(review): this notebook defines a class named ``SegformerTrainer`` in
    more than one cell; whichever cell was executed last is the one in effect.
    """

    def __init__(self, config):
        """Create all training components from ``config``.

        Args:
            config (dict): Training settings (see the class docstring).

        Raises:
            FileNotFoundError: If the image or ground-truth path is missing.
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Setup data loaders
        self.train_loader, self.val_loader = self._setup_data_loaders(
            img_path=config['img_path'],
            gt_path=config['gt_path'],
            validation_split=config['val_split'],
            batch_size=config['batch_size']
        )

        # Initialize model. Label ids are plain ints in both directions so the
        # two mappings are true inverses of each other.
        class_names = list(config['class_names'])
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            config['model_name'],
            num_labels=config['num_classes'],
            id2label={idx: label for idx, label in enumerate(class_names)},
            label2id={label: idx for idx, label in enumerate(class_names)}
        )
        self.model = self.model.to(self.device)

        # Setup optimizer and scheduler
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config['epochs'],
            eta_min=config['min_lr']
        )

        # Setup loss function with class weights. The ignore index is taken
        # from the model config (falling back to 255) so that pixels the
        # model's built-in loss would skip are skipped here as well.
        class_weights = self._calculate_class_weights()
        self.criterion = nn.CrossEntropyLoss(
            weight=class_weights.to(self.device),
            ignore_index=getattr(self.model.config, 'semantic_loss_ignore_index', 255)
        )

    def _setup_data_loaders(self, img_path, gt_path, validation_split, batch_size):
        """Validate the dataset paths and build the train/validation loaders.

        Args:
            img_path (str): Directory passed on as ``img_dir``.
            gt_path (str): Directory passed on as ``gt_dir``.
            validation_split: Forwarded unchanged to the loader factory.
            batch_size: Forwarded unchanged to the loader factory.

        Returns:
            Whatever ``get_train_data_loaders`` returns; the constructor
            unpacks it as ``(train_loader, val_loader)``.

        Raises:
            FileNotFoundError: If either path does not exist.
        """
        if not os.path.exists(img_path) or not os.path.exists(gt_path):
            raise FileNotFoundError(f"Paths {img_path} or {gt_path} do not exist.")

        return get_train_data_loaders(
            img_dir=img_path,
            gt_dir=gt_path,
            validation_split=validation_split,
            batch_size=batch_size
        )

    def _calculate_class_weights(self):
        """Return the per-class weights used by the loss criterion.

        The default is uniform weighting (1 for each class). Replace the body
        with a custom computation to re-balance the loss; the result is
        applied through ``self.criterion`` in ``_compute_loss``.

        Returns:
            torch.Tensor: 1-D tensor of length ``config['num_classes']``.
        """
        num_classes = self.config['num_classes']
        return torch.ones(num_classes)

    def _compute_loss(self, logits, masks):
        """Apply the class-weighted criterion to the model logits.

        The logits are bilinearly resized to the spatial size of ``masks``
        before the loss is taken (SegFormer predicts at a reduced resolution;
        this mirrors the resize done by the model's built-in loss).

        Args:
            logits (torch.Tensor): Raw class scores, shape ``(N, C, h, w)``.
            masks (torch.Tensor): Integer label maps, shape ``(N, H, W)``.

        Returns:
            torch.Tensor: Scalar loss.
        """
        resized_logits = nn.functional.interpolate(
            logits,
            size=masks.shape[-2:],
            mode='bilinear',
            align_corners=False
        )
        return self.criterion(resized_logits, masks)

    def train_epoch(self):
        """Run one optimisation pass over the training loader.

        Returns:
            float: Mean batch loss over the epoch.
        """
        self.model.train()
        total_loss = 0

        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, (images, masks) in enumerate(pbar):
            images = images.to(self.device)
            masks = masks.to(self.device)

            self.optimizer.zero_grad()

            # Labels are not passed to the model: the loss comes from
            # self.criterion so that custom class weights take effect.
            outputs = self.model(pixel_values=images)
            loss = self._compute_loss(outputs.logits, masks)

            loss.backward()
            self.optimizer.step()

            # Read the scalar once; each .item() call forces a device sync.
            batch_loss = loss.item()
            total_loss += batch_loss
            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

        return total_loss / len(self.train_loader)

    def validate(self):
        """Evaluate the model on the validation loader without gradients.

        Returns:
            float: Mean batch loss over the validation set.
        """
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for images, masks in tqdm(self.val_loader, desc='Validation'):
                images = images.to(self.device)
                masks = masks.to(self.device)

                outputs = self.model(pixel_values=images)
                loss = self._compute_loss(outputs.logits, masks)

                total_loss += loss.item()

        return total_loss / len(self.val_loader)

    def train(self):
        """Run the full training loop for ``config['epochs']`` epochs.

        After every epoch the scheduler is stepped and the losses are printed.
        The model is saved as ``best_model.pth`` whenever the validation loss
        improves, and as ``checkpoint_epoch_<n>.pth`` every
        ``config['save_every']`` epochs.
        """
        best_val_loss = float('inf')

        for epoch in range(self.config['epochs']):
            print(f"\nEpoch {epoch+1}/{self.config['epochs']}")

            train_loss = self.train_epoch()
            val_loss = self.validate()

            self.scheduler.step()

            print(f'Train Loss: {train_loss:.4f}')
            print(f'Val Loss: {val_loss:.4f}')

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_model('best_model.pth')

            # Regular checkpoint
            if (epoch + 1) % self.config['save_every'] == 0:
                self.save_model(f'checkpoint_epoch_{epoch+1}.pth')

    def save_model(self, filename):
        """Write model, optimizer and scheduler state plus the config to disk.

        Args:
            filename (str): File name created inside ``config['save_dir']``;
                the directory is created if it does not exist.
        """
        save_path = os.path.join(self.config['save_dir'], filename)
        os.makedirs(self.config['save_dir'], exist_ok=True)

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config
        }, save_path)

        print(f'Model saved to {save_path}')


In [6]:
if __name__ == '__main__':
    # Dataset root. It can be overridden through the LPCVC_DATA_ROOT
    # environment variable so the notebook is not tied to one machine; the
    # fallback is the original local path.
    data_root = os.environ.get(
        'LPCVC_DATA_ROOT',
        'D:/alib/LPCVC2023/data/LPCVC_Train_Updated/LPCVC_Train_Updated/LPCVC_Train_Updated'
    )

    # Label names; the list position is the class id.
    class_names = [
        'background', 'avalanche', 'building_undamaged', 'building_damaged',
        'cracks/fissure/subsidence', 'debris/mud//rock flow', 'fire/flare',
        'flood/water/river/sea', 'ice_jam_flow', 'lava_flow', 'person',
        'pyroclastic_flow', 'road/railway/bridge', 'vehicle'
    ]

    config = {
        # Images and ground-truth masks live in sibling folders under the root.
        'img_path': f'{data_root}/IMG/train',
        'gt_path': f'{data_root}/GT_Updated/train',
        'model_name': 'nvidia/mit-b0',
        # Derived from the label list so the two can never disagree.
        'num_classes': len(class_names),
        'class_names': class_names,
        'batch_size': 8,
        'epochs': 5,
        'learning_rate': 1e-4,
        'min_lr': 1e-6,
        'weight_decay': 0.01,
        'val_split': 0.2,
        'save_dir': 'checkpoints',
        'save_every': 5
    }

    trainer = SegformerTrainer(config)
    trainer.train()

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5


Training: 100%|██████████| 102/102 [00:39<00:00,  2.61it/s, loss=1.3026]
Validation: 100%|██████████| 26/26 [00:15<00:00,  1.72it/s]


Train Loss: 1.8558
Val Loss: 1.2771
Model saved to checkpoints\best_model.pth

Epoch 2/5


Training: 100%|██████████| 102/102 [00:34<00:00,  2.94it/s, loss=0.7973]
Validation: 100%|██████████| 26/26 [00:10<00:00,  2.53it/s]


Train Loss: 1.1427
Val Loss: 0.9483
Model saved to checkpoints\best_model.pth

Epoch 3/5


Training: 100%|██████████| 102/102 [00:33<00:00,  3.06it/s, loss=0.7515]
Validation: 100%|██████████| 26/26 [00:08<00:00,  3.13it/s]


Train Loss: 0.8800
Val Loss: 0.7674
Model saved to checkpoints\best_model.pth

Epoch 4/5


Training: 100%|██████████| 102/102 [00:33<00:00,  3.09it/s, loss=0.8382]
Validation: 100%|██████████| 26/26 [00:07<00:00,  3.38it/s]


Train Loss: 0.7715
Val Loss: 0.7112
Model saved to checkpoints\best_model.pth

Epoch 5/5


Training: 100%|██████████| 102/102 [00:32<00:00,  3.10it/s, loss=0.6184]
Validation: 100%|██████████| 26/26 [00:07<00:00,  3.45it/s]

Train Loss: 0.7165
Val Loss: 0.6996
Model saved to checkpoints\best_model.pth
Model saved to checkpoints\checkpoint_epoch_5.pth



