In [1]:
from utils import iou_score
# from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch
import argparse
import pandas as pd
from tqdm import tqdm
import numpy as np
import wandb
from model import get_model
from dataloader import CorroSeg

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def _to_model_input(image, device):
    """Move an image batch to ``device`` and tile it into a 3-channel tensor.

    A channel axis is inserted at dim 1 and repeated 3 times. Assumes ``image``
    is (batch, H, W) single-channel data (TODO confirm against ``CorroSeg``);
    the tiling presumably exists because the backbone expects RGB-shaped input.
    """
    return image.to(device).unsqueeze(1).repeat(1, 3, 1, 1)


def _run_epoch(model, loader, criterion, device, threshold, optimizer=None, desc=None):
    """Make one pass over ``loader`` and return ``(mean_loss, mean_iou)``.

    Trains (forward, backward, optimizer step) when ``optimizer`` is given;
    otherwise evaluates with gradients disabled. Both metrics are per-sample
    means: each batch value is weighted by the batch size and the sums are
    divided by ``len(loader.dataset)``.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    total_iou = 0.0
    with torch.set_grad_enabled(training):
        for image, mask, _well in tqdm(loader, desc=desc, leave=False):
            outputs = model(_to_model_input(image, device))
            # Shape the mask like the prediction (instead of a hard-coded
            # (-1, 1, 36, 36)) and match its dtype: nn.BCELoss needs a float
            # target with the same dtype and shape as its input.
            target = mask.to(device=device, dtype=outputs.dtype).reshape_as(outputs)
            loss = criterion(outputs, target)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = image.size(0)
            total_loss += loss.item() * batch_size
            preds = outputs.detach() > threshold  # binary predictions
            total_iou += iou_score(preds, target).item() * batch_size

    n_samples = len(loader.dataset)
    return total_loss / n_samples, total_iou / n_samples


def _predict(model, loader, device, threshold):
    """Return a 2-D numpy array with one flattened, thresholded prediction per sample.

    NOTE(review): values are booleans, so a CSV written from this holds
    True/False; cast (e.g. ``.astype(np.uint8)``) if 0/1 is expected.
    """
    model.eval()
    rows = []
    with torch.no_grad():
        for image, _, _ in loader:  # masks/extra fields are ignored at test time
            preds = (model(_to_model_input(image, device)) > threshold).cpu().numpy()
            rows.extend(preds.reshape(preds.shape[0], -1))
    return np.array(rows)


def main(args):
    """Train, validate and run test-set inference for a CorroSeg segmentation model.

    Args:
        args: ``argparse.Namespace`` with the fields set in the configuration
            cell (wandb, experiment_name, wandb_id, wandb_entity, num_epochs,
            batch_size, valid_ratio, model_name, backbone, learning_rate,
            threshold, unfreeze_at_epoch, layers_to_unfreeze_each_time,
            weight_decay). An optional ``seed`` attribute seeds Python, NumPy
            and torch RNGs.

    Side effects: optional Weights & Biases logging, per-epoch progress
    prints, and writing the thresholded test predictions to
    ``predicted_masks.csv``.

    Raises:
        ValueError: if ``args.unfreeze_at_epoch`` is not a positive integer
            (it is used as a modulus / divisor below).
    """
    if args.unfreeze_at_epoch <= 0:
        raise ValueError(
            f"unfreeze_at_epoch must be a positive integer, got {args.unfreeze_at_epoch}"
        )

    seed = getattr(args, "seed", None)
    if seed is not None:
        import random

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    if args.wandb:
        wandb.init(
            name=args.experiment_name,
            id=args.wandb_id,
            entity=args.wandb_entity,
            project="corroseg",
            # Pass the config to init so it is attached to the run; assigning
            # `wandb.config = {...}` only rebinds the module attribute.
            config={
                "architecture": args.model_name,
                "epochs": args.num_epochs,
                "learning_rate": args.learning_rate,
            },
        )

    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = get_model(args.model_name, backbone=args.backbone).to(device)

        corro_seg = CorroSeg('data', 'y_train.csv', shuffle=True,
                             batch_size=args.batch_size, valid_ratio=args.valid_ratio,
                             transform_img=None, transform_mask=None, transform_test=None,
                             test_params={'batch_size': args.batch_size, 'shuffle': False})
        train_loader, val_loader, test_loader = corro_seg.get_loaders()

        # nn.BCELoss expects probabilities in [0, 1], so the model is presumably
        # sigmoid-terminated — TODO confirm in get_model.
        criterion = nn.BCELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                     weight_decay=args.weight_decay)

        for epoch in tqdm(range(args.num_epochs)):
            # Gradual unfreezing: every `unfreeze_at_epoch` epochs, pass a growing
            # count (0 at epoch 0) to the model — presumably the total number of
            # layers to unfreeze; confirm in model.unfreeze_layers.
            if epoch % args.unfreeze_at_epoch == 0:
                layers_to_unfreeze = (epoch // args.unfreeze_at_epoch) * args.layers_to_unfreeze_each_time
                model.unfreeze_layers(layers_to_unfreeze)

            train_loss, train_iou = _run_epoch(model, train_loader, criterion, device,
                                               args.threshold, optimizer=optimizer, desc="train")
            val_loss, val_iou = _run_epoch(model, val_loader, criterion, device,
                                           args.threshold, desc="val")

            if args.wandb:
                wandb.log({'Train Loss': train_loss, 'Train IoU': train_iou,
                           'Validation Loss': val_loss, 'Validation IoU': val_iou}, step=epoch)

            print(f'Epoch {epoch+1}/{args.num_epochs}, Train Loss: {train_loss:.4f}, Train IoU: {train_iou:.4f}, Validation Loss: {val_loss:.4f}, Validation IoU: {val_iou:.4f}')

        predicted_masks = _predict(model, test_loader, device, args.threshold)
        pd.DataFrame(predicted_masks).to_csv("predicted_masks.csv", index=False)
        print("Predicted masks saved to predicted_masks.csv")
    finally:
        # Close the W&B run even if training fails or is interrupted.
        if args.wandb:
            wandb.finish()

In [11]:
# Run configuration for `main`, packaged as a Namespace so it looks exactly
# like what argparse would produce from a command line.
args = argparse.Namespace(**{
    "wandb": False,                      # enable Weights & Biases logging
    "experiment_name": "test4",          # W&B run name
    "output_dir": "wandb",               # not read by `main` in this notebook
    "wandb_id": None,                    # W&B run id passed to wandb.init
    "wandb_entity": "lucasgascon",       # W&B entity (user/team)
    "num_epochs": 100,
    "batch_size": 64,                    # used for train/val and test loaders
    "valid_ratio": 0.1,                  # passed to CorroSeg; presumably the validation fraction
    "model_name": "binary_segmentation", # forwarded to get_model
    "backbone": "resnet50",              # forwarded to get_model
    "learning_rate": 2e-5,               # Adam learning rate
    "threshold": 0.5,                    # probability cut-off for binary predictions
    "unfreeze_at_epoch": 3,              # unfreeze more layers every N epochs
    "layers_to_unfreeze_each_time": 1,   # layer-count increment per unfreezing step
    "weight_decay": 0.01,                # Adam weight decay
})

In [12]:
# Launch training, validation and test-set inference with the configuration cell's `args`.
main(args=args)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /Users/lucasgascon/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
  1%|          | 952k/97.8M [00:06<11:48, 143kB/s]  


KeyboardInterrupt: 