In [1]:
import torch
import torch.nn as nn
import argparse
import pandas as pd
import wandb
from tqdm import tqdm
from model import get_model
from dataloader import CorroSeg
from utils import iou_score,  RollTransform
from losses import SoftIoULoss, FocalLoss
from dataloader import CorroSeg, CorroSegDataset
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np 
import os
import datetime
from datetime import datetime
import easydict

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Presumed fraction of positive (corroded) pixels in the training masks; its
# inverse is passed to FocalLoss below. NOTE(review): confirm the 7% figure
# against the training labels.
freq_corrosion = 7/100

# Masks are handled as square _MASK_SIDE x _MASK_SIDE grids (train/val masks are
# reshaped to (-1, 1, 36, 36) and test predictions flattened to 36*36 columns).
_MASK_SIDE = 36


def _build_criterion(args):
    """Return the loss module selected by ``args.criterion``.

    Raises:
        ValueError: if ``args.criterion`` is not 'bce', 'iou' or 'focal'. (Without
            this check an unknown name only surfaced later as a NameError.)
    """
    if args.criterion == 'bce':
        return nn.BCEWithLogitsLoss()
    if args.criterion == 'iou':
        return SoftIoULoss()
    if args.criterion == 'focal':
        # Second argument is 1/freq_corrosion; presumably a weight for the rarer
        # positive class -- NOTE(review): confirm against FocalLoss's signature.
        return FocalLoss(args.gamma, 1/freq_corrosion)
    raise ValueError(
        f"Unknown criterion {args.criterion!r}; expected 'bce', 'iou' or 'focal'"
    )


def _prepare_image(image, args, device):
    """Optionally average channels to a single gray channel, then move to ``device``."""
    if args.model_need_GRAY:
        image = torch.mean(image, dim=1, keepdim=True)
    return image.to(device)


def _binarize(outputs, threshold):
    """Return float {0., 1.} predictions: 1 where ``outputs > threshold``.

    This is the same rule the test phase uses. The previous
    ``(outputs - threshold).round()`` did not binarize: for outputs in [0, 1]
    it always produced 0 (round-half-to-even), and for unbounded outputs it
    produced arbitrary integers.
    """
    return (outputs > threshold).float()


def _run_epoch(model, loader, criterion, device, args, optimizer=None):
    """Run one pass over ``loader``, training if ``optimizer`` is given, else evaluating.

    Returns:
        (mean_loss, mean_iou): per-sample averages over ``len(loader.dataset)``,
        each batch weighted by its size; ``(nan, nan)`` for an empty dataset.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    total_iou = 0.0

    # Gradients are only needed when we are going to step the optimizer.
    with torch.set_grad_enabled(training):
        for image, mask, _well in tqdm(loader):
            image = _prepare_image(image, args, device)
            mask = mask.view(-1, 1, _MASK_SIDE, _MASK_SIDE).to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(image)
            loss = criterion(outputs, mask)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = image.size(0)
            total_loss += loss.item() * batch_size
            preds = _binarize(outputs.detach(), args.threshold)
            total_iou += iou_score(preds, mask).item() * batch_size

    n_samples = len(loader.dataset)
    if n_samples == 0:
        return float('nan'), float('nan')
    return total_loss / n_samples, total_iou / n_samples


def _predict_test(model, loader, device, args, invalid_below=-100):
    """Return a 2-D int array with one flattened binary mask per test sample.

    A sample whose (possibly gray-converted) input contains any value below
    ``invalid_below`` gets an all-zero mask. This is decided per sample; it
    previously zeroed the entire batch if any single image was invalid.
    """
    model.eval()
    batches = []
    with torch.no_grad():
        for image, _, _ in loader:  # test masks are ignored
            image = _prepare_image(image, args, device)
            preds = (model(image) > args.threshold).int()

            # One flag per sample: does this image contain any invalid value?
            invalid = (image < invalid_below).flatten(start_dim=1).any(dim=1)
            preds[invalid] = 0

            batches.append(preds.cpu().numpy().reshape(-1, _MASK_SIDE * _MASK_SIDE))
    return np.vstack(batches)


def _save_predictions(predicted_masks, experiment_name,
                      test_dir='data/processed/images_test',
                      out_dir='data/predictions'):
    """Write predictions to ``<out_dir>/submission_<experiment_name>.csv`` and return that path.

    Rows are indexed by the test file names with '.npy' stripped.
    """
    # NOTE(review): os.listdir order is arbitrary; this must match the order in
    # which the test loader yields samples -- confirm against CorroSeg.
    files = [f.replace('.npy', '') for f in os.listdir(test_dir)]
    df = pd.DataFrame(predicted_masks, index=files)

    os.makedirs(out_dir, exist_ok=True)
    prediction_path = os.path.join(out_dir, "submission_" + experiment_name + '.csv')
    df.to_csv(prediction_path, index=True)
    return prediction_path


def train(args):
    """Train a CorroSeg segmentation model, validate every epoch, and write test predictions.

    Args:
        args: attribute-style config (e.g. ``easydict.EasyDict``) providing
            experiment_name, wandb, wandb_id, wandb_entity, model_name, backbone,
            num_epochs, learning_rate, weight_decay, batch_size, valid_ratio,
            criterion, gamma, threshold, model_need_GRAY, defreezing_strategy,
            unfreeze_at_epoch and layers_to_unfreeze_each_time. If
            ``args.experiment_name`` is None it is set in place to the current
            timestamp.

    Side effects:
        Optionally logs to Weights & Biases. Writes
        ``data/predictions/submission_<experiment_name>.csv``.

    Raises:
        ValueError: for an unknown ``args.criterion``, or when the defreezing
            strategy is enabled with a non-positive ``unfreeze_at_epoch``.
    """
    if args.experiment_name is None:
        args.experiment_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Fail fast instead of hitting a modulo-by-zero inside the epoch loop.
    if args.defreezing_strategy and args.unfreeze_at_epoch <= 0:
        raise ValueError("unfreeze_at_epoch must be positive when defreezing_strategy is enabled")

    if args.wandb:
        # Pass the config to init(); assigning a dict to ``wandb.config`` just
        # replaces the module attribute and is never recorded with the run.
        wandb.init(
            name=args.experiment_name,
            id=args.wandb_id,
            entity=args.wandb_entity,
            project="corroseg",
            config={
                "architecture": args.model_name,
                "epochs": args.num_epochs,
                "learning_rate": args.learning_rate,
            },
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_model(model_name=args.model_name, backbone_name=args.backbone).to(device)

    # Image transforms handed to CorroSeg (None = no transform). NOTE(review):
    # presumably each entry yields one augmented copy of the training set --
    # confirm against CorroSeg.
    transform_img = [
        None,
        transforms.RandomHorizontalFlip(1),
        transforms.RandomVerticalFlip(1),
        RollTransform(),
        transforms.Compose([transforms.RandomVerticalFlip(1), transforms.RandomHorizontalFlip(1)]),
    ]

    corro_seg = CorroSeg('data', 'y_train.csv', shuffle=True,
                         batch_size=args.batch_size, valid_ratio=args.valid_ratio,
                         transform_img=transform_img, transform_test=None,
                         test_params={'batch_size': args.batch_size, 'shuffle': False})
    train_loader, val_loader, test_loader = corro_seg.get_loaders()
    print("Data loaded")

    criterion = _build_criterion(args)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    for epoch in tqdm(range(args.num_epochs)):
        # Defreezing strategy: every `unfreeze_at_epoch` epochs, unfreeze a
        # growing number of layers.
        if args.defreezing_strategy and epoch % args.unfreeze_at_epoch == 0:
            layers_to_unfreeze = (epoch // args.unfreeze_at_epoch) * args.layers_to_unfreeze_each_time
            model.unfreeze_layers(layers_to_unfreeze)

        train_loss, train_iou = _run_epoch(model, train_loader, criterion, device, args, optimizer)
        val_loss, val_iou = _run_epoch(model, val_loader, criterion, device, args)

        if args.wandb:
            wandb.log({'Train Loss': train_loss, 'Train IoU': train_iou,
                       'Validation Loss': val_loss, 'Validation IoU': val_iou}, step=epoch)

        print(f'Epoch {epoch+1}/{args.num_epochs}, Train Loss: {train_loss:.4f}, '
              f'Train IoU: {train_iou:.4f}, Validation Loss: {val_loss:.4f}, '
              f'Validation IoU: {val_iou:.4f}')

    # NOTE(review): with criterion='bce' (BCEWithLogitsLoss) the model presumably
    # emits logits, in which case threshold 0.5 means probability ~0.62 -- confirm
    # what get_model returns.
    predicted_masks = _predict_test(model, test_loader, device, args)
    prediction_path = _save_predictions(predicted_masks, args.experiment_name)
    print(f"Predicted masks saved to {prediction_path}")


In [6]:
# Run configuration consumed by train(); keys keep their original order.
args = easydict.EasyDict({
    'num_epochs': 4,
    'criterion': 'focal',                # one of 'bce', 'iou', 'focal'
    'batch_size': 64,
    'valid_ratio': 0.1,                  # forwarded to CorroSeg as valid_ratio
    'model_name': 'cnn',
    'backbone': 'efficientnet-v2-m',
    'learning_rate': 2e-5,
    'threshold': 0.5,                    # model output > threshold => positive pixel
    'defreezing_strategy': False,
    'unfreeze_at_epoch': 0,              # only read when defreezing_strategy is True
    'layers_to_unfreeze_each_time': 100,
    'weight_decay': 0.01,
    'gamma': 3,                          # first argument to FocalLoss
    'experiment_name': 'test',           # used in the submission CSV file name
    'wandb': False,
    'wandb_id': None,
    'wandb_entity': None,
    'output_dir': 'wandb',               # not read by train() as written
    'model_need_GRAY': False,            # average channels to one before the model
})

In [7]:
# Train, validate each epoch, then write test-set predictions to data/predictions/.
train(args=args)

Data loaded


  0%|          | 0/4 [00:00<?, ?it/s]

: 