In [5]:
import argparse
import os

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import transforms

from ImageDataset import *
from unet import *
from loss import *

## Unzip the data
Run this cell only when working on Google Colab.


In [4]:
import zipfile
import os

# Paths to the uploaded zip archives
zip_path = 'data.zip'
pred_path = 'predictions.zip'

# Directories the archives are extracted into. These are relative paths so
# that they match the 'data/...' and 'predictions/' locations used by the
# other cells (the previous root-level '/data2/' and '/predictions2/'
# directories were created but never used as extraction targets).
extract_path = 'data'
extract_pred = 'predictions'

# Create each destination directory if needed, then extract the archive into it
for archive, destination in ((zip_path, extract_path), (pred_path, extract_pred)):
    os.makedirs(destination, exist_ok=True)
    with zipfile.ZipFile(archive, 'r') as zip_ref:
        zip_ref.extractall(destination)


In [20]:
# --- Data locations (relative to the notebook's working directory) ---
TRAIN_IMAGES = 'data/training/images/'
GROUNDTRUTH = 'data/training/groundtruth/'
TEST_IMAGES = 'data/test_set_images/'

# --- Data loading ---
BATCH_SIZE = 10   # batch size of the training DataLoader
WORKERS = 2       # num_workers of the DataLoaders
SPLIT_RATIO = 0.9             # NOTE(review): not referenced in the visible cells
FOREGROUND_TRESHOLD = 0.25    # NOTE(review): not referenced in the visible cells

# --- Optimisation ---
LR = 1e-3            # learning rate given to Adam
WEIGHT_DECAY = 1e-3  # weight decay given to Adam
EPOCHS = 30          # NOTE(review): the training cell uses its own `epochs` value

# --- Reproducibility ---
SEED = 0             # NOTE(review): defined but no seeding call is visible

In [12]:
# Define device: use the GPU when one is available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Pinned host memory only helps when batches are copied to a CUDA device.
# `device` is a torch.device object, so it must be inspected through `.type`;
# comparing the object itself against the string 'cuda' evaluates to False
# and would silently disable pinning even on a GPU machine.
pin_memory = device.type == 'cuda'

Device: cpu


In [4]:
# Images and masks go through the same single-step pipeline: conversion to a
# tensor. No other transformation is applied.
image_transform = transforms.Compose([transforms.ToTensor()])
mask_transform = transforms.Compose([transforms.ToTensor()])

In [5]:
# Training dataset built from the image and ground-truth directories, each
# with its own transform pipeline.
dataset = ImagesDataset(
    img_dir=TRAIN_IMAGES, image_transform=image_transform,
    gt_dir=GROUNDTRUTH, mask_transform=mask_transform,
)

In [6]:
# Sanity check: display how many samples the dataset contains.
len(dataset)

100

In [7]:
# Fetch the first sample and print the shapes of its image and mask tensors
# to verify that loading and the transforms behave as expected.
image, mask = dataset[0]
print('Image size:', image.shape)
print('Mask size:', mask.shape)

Image size: torch.Size([3, 400, 400])
Mask size: torch.Size([1, 400, 400])


In [8]:
# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=WORKERS, pin_memory=pin_memory,
)

In [9]:
# Build the network, then move its parameters to the selected device.
model = UNet()
model = model.to(device)

In [10]:
# Loss function used by the training loop
criterion = DiceLoss()

# Adam over all model parameters, configured from the notebook constants
optimizer = Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

In [11]:
# Learning rate scheduler: lowers the learning rate once the monitored value
# (mode='min', i.e. lower is better) has not improved for `patience` epochs.
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)

In [12]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm


# Set the model in training mode
model.train()

# Define the number of training epochs
epochs = 10  # You can adjust this as needed

# Training loop
for epoch in range(epochs):
    # Sum of the per-batch losses of the current epoch
    total_loss = 0.0

    # Iterate over the training data
    for data, target in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch'):
        # Send the input to the device
        data, target = data.to(device), target.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        output = model(data)

        # Calculate the loss
        loss = criterion(output, target)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Update the total loss
        total_loss += loss.item()

    # Average loss for the epoch (mean over batches)
    average_loss = total_loss / len(train_loader)
    print("average loss: ", average_loss)

    # Adjust learning rate if a scheduler is provided; the scheduler monitors
    # the average training loss of the epoch
    if lr_scheduler is not None:
        lr_scheduler.step(average_loss)

# Save the trained weights. The filename is derived from the actual number of
# epochs and the actual batch size so that it cannot drift from the settings
# used for this run (the previous hard-coded name claimed a batch size of 30).
torch.save(model.state_dict(), f'trained_model_{epochs}ep_{BATCH_SIZE}batch.pth')

# Save the trained model
#torch.save(model.state_dict(), 'trained_model.pth')

# Now, you can use the trained model for predictions
# For example, if you have a test DataLoader, you can do:
# model.eval()
# with torch.no_grad():
#     for test_data in test_data_loader:
#         test_data = test_data.to(device)
#         predictions = model(test_data)
#         # Process predictions as needed based on proba_threshold
#         # ...

# Note: This is a basic example, and you might need to adapt it based on your specific requirements and dataset structure.


Epoch 1/10: 100%|██████████| 10/10 [03:12<00:00, 19.28s/batch]


average loss:  0.6479327440261841


Epoch 2/10: 100%|██████████| 10/10 [03:05<00:00, 18.56s/batch]


average loss:  0.5524351000785828


Epoch 3/10: 100%|██████████| 10/10 [03:10<00:00, 19.05s/batch]


average loss:  0.5209612488746643


Epoch 4/10: 100%|██████████| 10/10 [03:09<00:00, 18.95s/batch]


average loss:  0.4906199038028717


Epoch 5/10: 100%|██████████| 10/10 [03:10<00:00, 19.06s/batch]


average loss:  0.45297694206237793


Epoch 6/10: 100%|██████████| 10/10 [03:11<00:00, 19.19s/batch]


average loss:  0.4200475513935089


Epoch 7/10: 100%|██████████| 10/10 [03:15<00:00, 19.57s/batch]


average loss:  0.4059355616569519


Epoch 8/10: 100%|██████████| 10/10 [03:22<00:00, 20.30s/batch]


average loss:  0.3777090311050415


Epoch 9/10: 100%|██████████| 10/10 [03:21<00:00, 20.15s/batch]


average loss:  0.3745973765850067


Epoch 10/10: 100%|██████████| 10/10 [03:28<00:00, 20.88s/batch]

average loss:  0.3589638650417328





## Testing and submission

In [22]:
# Test images are only converted to tensors, like the training images.
test_image_transform = transforms.Compose([transforms.ToTensor()])

In [24]:
# Test dataset: only an image directory is given (no ground-truth directory).
test_set = ImagesDataset(img_dir=TEST_IMAGES, image_transform=test_image_transform)

In [25]:
# No batch_size / shuffle given: the DataLoader defaults apply (one sample per
# batch, sequential order), so the batch index identifies the image.
test_loader = DataLoader(test_set, num_workers=WORKERS, pin_memory=pin_memory)

In [6]:
def _get_pred_filename(lenth_loader, index: int) -> str:
    """Returns the filename of the prediction.

    Args:
        index (int): index of the image in the dataset.

    Returns:
        str: filename of the prediction.
    """
    if lenth_loader > 1000:
        return f'prediction_{index + 1:04d}.png'
    return f'prediction_{index + 1:03d}.png'

In [8]:
def _predict_labels(
    output: torch.Tensor,
    proba_threshold: float,
) -> torch.Tensor:
    """Predicts the labels for an output.

    Args:
        output (torch.Tensor): tensor output.
        proba_threshold (float): probability threshold.

    Returns:
        torch.Tensor: tensor of 0 and 1.
    """
    return (output > proba_threshold).type(torch.uint8)

In [9]:
def _save_mask(
    output: torch.Tensor,
    filename: str,
) -> None:
    """Writes a mask tensor to disk as an image.

    The tensor is multiplied by 255, stripped of its size-1 dimensions,
    moved to the CPU and converted to a NumPy array before being saved.

    Args:
        output (torch.Tensor): mask tensor to save.
        filename (str): path of the image file to write.
    """
    scaled = output * 255
    pixels = scaled.squeeze().cpu().numpy()
    Image.fromarray(pixels).save(filename)

In [13]:
# Directory (relative to the working directory) where predicted masks are written
predictions_path= 'predictions/'

In [14]:
# Paths of the saved prediction images, filled in by the prediction loop
prediction_filnames = []

In [16]:
# Optional: run this cell only to use previously saved weights instead of the
# model trained above. It replaces `model` with a fresh UNet and loads the
# checkpoint into it; map_location places the loaded tensors on `device`.
# NOTE(review): torch.load unpickles the file - only load checkpoints that
# come from a trusted source.
model= UNet().to(device)
state_dict = torch.load('models/trained_model_100ep_10batch.pth', map_location=device)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [26]:
# Set the model in evaluation mode
model.eval()

# Make sure the output directory exists before any mask is written
os.makedirs(predictions_path, exist_ok=True)

# Start from an empty list so that re-running this cell does not append the
# same paths a second time (which would duplicate entries in the submission)
prediction_filnames.clear()

# The zero-padding of the filenames depends on the number of items; it does
# not change during the loop, so compute it once
num_test_items = len(test_loader)

# Switch off autograd
with torch.no_grad():
    # Loop over the dataset; `target` is not needed for inference
    for i, (data, target) in enumerate(test_loader):
        filename = _get_pred_filename(num_test_items, i)
        print(f'Processing {filename}')

        # Send the input to the device
        data = data.to(device)

        # Make the predictions
        output = model(data)

        # Get labels: 1 where the output exceeds the 0.25 threshold, else 0
        output = _predict_labels(output, 0.25)

        # Save mask and remember where it was written
        output_path = os.path.join(predictions_path, filename)
        _save_mask(output, output_path)
        prediction_filnames.append(output_path)

# Print a message after processing all images
print('Prediction completed.')


Processing prediction_001.png
Processing prediction_002.png
Processing prediction_003.png
Processing prediction_004.png
Processing prediction_005.png
Processing prediction_006.png
Processing prediction_007.png
Processing prediction_008.png
Processing prediction_009.png
Processing prediction_010.png
Processing prediction_011.png
Processing prediction_012.png
Processing prediction_013.png
Processing prediction_014.png
Processing prediction_015.png
Processing prediction_016.png
Processing prediction_017.png
Processing prediction_018.png
Processing prediction_019.png
Processing prediction_020.png
Processing prediction_021.png
Processing prediction_022.png
Processing prediction_023.png
Processing prediction_024.png
Processing prediction_025.png
Processing prediction_026.png
Processing prediction_027.png
Processing prediction_028.png
Processing prediction_029.png
Processing prediction_030.png
Processing prediction_031.png
Processing prediction_032.png
Processing prediction_033.png
Processing

In [27]:
# NOTE(review): masks_to_submission presumably comes from `helper` through
# this star import - confirm.
from helper import *
# Pass the output CSV name followed by every saved prediction path as
# separate positional arguments.
masks_to_submission('submission.csv', *prediction_filnames)