# Logging utils

In [None]:
import logging
import os

from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

import shutil

log_dir = 'logs'
log_file_path = os.path.join(log_dir, 'run.log')

LOGGER = logging.getLogger('notebook')
LOGGER.setLevel(logging.DEBUG)

# Re-running this cell reuses the same 'notebook' logger. Detach and close any
# handlers left from a previous run first; otherwise every message is emitted
# once per run and the old FileHandler keeps writing to the deleted log file.
for _old_handler in list(LOGGER.handlers):
    LOGGER.removeHandler(_old_handler)
    _old_handler.close()

# Start each run with an empty log directory (same effect as `rm -rf logs`).
shutil.rmtree(log_dir, ignore_errors=True)
os.makedirs(log_dir, exist_ok=True)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')


class _TqdmConsoleHandler(logging.StreamHandler):
    """StreamHandler that writes through ``tqdm.write`` so log lines don't garble progress bars.

    Used instead of tqdm's ``logging_redirect_tqdm``, which is a context manager
    and only has an effect inside a ``with`` block.
    """

    def emit(self, record):
        try:
            tqdm.write(self.format(record), file=self.stream)
            self.flush()
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)


file_handler = logging.FileHandler(log_file_path)
file_handler.setFormatter(formatter)
LOGGER.addHandler(file_handler)

console_handler = _TqdmConsoleHandler()
console_handler.setFormatter(formatter)
LOGGER.addHandler(console_handler)

LOGGER.debug("This is a debug message")
LOGGER.info("This is an info message")
LOGGER.warning("This is a warning message")
LOGGER.error("This is an error message")
LOGGER.critical("This is a critical message")



In [None]:
import functools
import time


def timeit(func):
    """Decorator that logs how long each call to *func* took, at INFO level on LOGGER.

    ``functools.wraps`` keeps the wrapped function's ``__name__``/``__doc__``
    (and exposes it as ``__wrapped__``). ``time.perf_counter`` is used because
    it is monotonic and high-resolution, unlike wall-clock ``time.time``.
    Nothing is logged if *func* raises.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        LOGGER.info(f"Function '{func.__name__}' executed in {elapsed_time:.4f} seconds")
        return result
    return wrapper

# Download dataset

In [None]:
! pip install huggingface_hub
! pip install datasets

In [None]:
from huggingface_hub import login
from datasets import load_dataset
from huggingface_hub import snapshot_download
import sys

dataset_id = "RayanAi/Main_teeth_dataset"
local_dataset_dir = "./Main_teeth_dataset"

os.makedirs(local_dataset_dir, exist_ok=True)

# Silence anything snapshot_download prints to stdout; the real stdout is put
# back even if the download fails.
original_stdout = sys.stdout
with open(os.devnull, 'w') as fnull:
    sys.stdout = fnull
    try:
        snapshot_download(repo_id=dataset_id, local_dir=local_dataset_dir, repo_type="dataset")
    finally:
        sys.stdout = original_stdout

LOGGER.info("Dataset downloaded completely.")

# Sum the on-disk size of every file under the download directory.
total_size = sum(
    os.path.getsize(os.path.join(root, name))
    for root, _subdirs, names in os.walk(local_dataset_dir)
    for name in names
)
LOGGER.info(f"Total size of downloaded files: {total_size / (1024 * 1024):.2f} MB")

dataset_abs_path = os.path.abspath(local_dataset_dir)
LOGGER.info(f"Dataset has been saved at: [{dataset_abs_path}]")

In [None]:
!unzip -q ./Main_teeth_dataset/Main_teeth_dataset.zip -d ./Main_teeth_dataset/

# Dataset

In [None]:
! pip install albumentations

In [None]:
from PIL import Image

def load_image(image_path):
    """Open *image_path* and return it as a single-channel ('L') PIL image.

    Returns None (after logging an error) if the file cannot be opened or
    decoded, so callers must check for None. The ``with`` block guarantees the
    underlying file handle is closed on both the success and failure paths.
    """
    try:
        with Image.open(image_path) as img:
            # convert() loads the pixel data and returns a new image, so the
            # result stays valid after the source file is closed.
            return img.convert('L')
    except Exception as e:
        LOGGER.error(f"Error loading image {image_path}: {e}")
        return None


In [None]:
from typing import List, Tuple
import random

import torch
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import numpy as np


class TeethDataset(Dataset):
    """In-memory segmentation dataset of (image, binary mask) pairs.

    Intended flow: ``load`` file paths -> optionally ``augment`` -> ``transform``.
    Only ``transform`` fills ``X``/``y``, so ``len()`` is 0 until it has run.
    """

    def __init__(self):
        # Raw (image, label) arrays as decoded by load_image (grayscale).
        self.loaded_data: List[Tuple[np.ndarray, np.ndarray]] = []
        # Model-ready tensors produced by transform().
        self.X: List[torch.Tensor] = []
        self.y: List[torch.Tensor] = []

    def load(self, data: List[Tuple[str, str]]):
        """Decode each (image_path, label_path) pair and append it as numpy arrays.

        Pairs where either file fails to load are skipped with a warning, instead
        of being stored as a meaningless ``np.array(None)`` that would only fail
        later inside augmentation/transformation.
        """
        for image_path, label_path in data:
            image = load_image(image_path)
            label = load_image(label_path)
            if image is None or label is None:
                LOGGER.warning(f"Skipping pair ({image_path}, {label_path}): failed to load")
                continue
            self.loaded_data.append((np.array(image), np.array(label)))

    def augment(self, generation_per_sample: int, augmentation: A.Compose, seed: int = 68):
        """Append ``generation_per_sample`` augmented copies of every loaded pair, then shuffle.

        The image and its label go through the same ``augmentation`` call, so
        they receive the same spatial transform. ``seed`` fixes the shuffle order.
        """
        LOGGER.info("Augmenting data ...")
        augmented_data = []
        for image, label in tqdm(self.loaded_data):
            for _ in range(generation_per_sample):
                augmented = augmentation(image=image, mask=label)
                augmented_data.append((augmented['image'], augmented['mask']))
        self.loaded_data.extend(augmented_data)
        # A private RNG gives the same order as `random.seed(seed); random.shuffle(...)`
        # without resetting the global `random` state that other code relies on
        # (e.g. random checkpoint-name suffixes).
        random.Random(seed).shuffle(self.loaded_data)

    def transform(self, transformation: A.Compose):
        """Apply ``transformation`` to every pair and store model-ready tensors in X / y.

        Labels become float masks of shape (1, H, W) holding 1.0 where the
        transformed label is > 0 and 0.0 elsewhere.
        """
        for image, label in self.loaded_data:
            transformed = transformation(image=image, mask=label)
            # The transformed mask is a 2-D (H, W) tensor; add the channel axis.
            binary_mask = (transformed['mask'] > 0).unsqueeze(0).float()
            self.X.append(transformed['image'])
            self.y.append(binary_mask)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# Load image paths

In [None]:
import glob

labels_root = os.path.join(local_dataset_dir, 'labels')
images_root = os.path.join(local_dataset_dir, 'images')

# glob returns files in arbitrary, filesystem-dependent order; sort so the
# path lists are reproducible across machines and runs.
images_path = sorted(glob.glob(os.path.join(images_root, '*.png')))
labels_path = sorted(glob.glob(os.path.join(labels_root, '*.png')))

# Cell output: number of label files and image files found.
len(labels_path), len(images_path)

# Remove noisy samples (high-entropy label masks and their images)

In [None]:
from typing import List

import numpy as np


def calculate_entropy(image):
    """Return the Shannon entropy, in bits, of the 256-bin intensity histogram of *image*.

    The histogram covers [0, 256) with unit-width bins, so the density values
    are per-intensity probabilities. A small constant avoids log2(0).
    """
    probs = np.histogram(np.array(image), bins=256, range=(0, 256), density=True)[0]
    return -(probs * np.log2(probs + 1e-7)).sum()


def get_noisy_images(images_path: List[str], threshold: float):
    """Return the paths whose grayscale entropy is >= *threshold*.

    Files that load_image cannot read are skipped (neither flagged nor kept
    track of). The input order is preserved in the result.
    """
    flagged = []
    for path in images_path:
        picture = load_image(path)
        if picture is not None and calculate_entropy(picture.convert("L")) >= threshold:
            flagged.append(path)
    return flagged

In [None]:
# Label masks whose grayscale entropy reaches this value are treated as noisy.
NOISE_ENTROPY_THRESHOLD = 0.2

# logging treats extra positional args as %-format arguments for the message,
# so counts are formatted into the message text explicitly.
LOGGER.info(f"Found {len(images_path)} images and {len(labels_path)} labels")

# NOTE(review): unseeded shuffle -> the train/val/test split taken from these
# lists later differs on every run.
random.shuffle(labels_path)

noisy_images = get_noisy_images(labels_path, NOISE_ENTROPY_THRESHOLD)
LOGGER.info(f"Noisy label masks (entropy >= {NOISE_ENTROPY_THRESHOLD}): {len(noisy_images)}")

# Set membership keeps the filter linear instead of O(len(labels) * len(noisy)).
noisy_set = set(noisy_images)
labels_path = [label_path for label_path in labels_path if label_path not in noisy_set]
# Each image is assumed to share its label's file name inside images_root.
images_path = [os.path.join(images_root, os.path.basename(file_path)) for file_path in labels_path]

LOGGER.info(f"Kept {len(images_path)} images and {len(labels_path)} labels")


In [None]:
_cuda_available = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _cuda_available else 'cpu')

# Report which accelerator (if any) will be used.
if not _cuda_available:
    LOGGER.info("No GPU available. Training will run on CPU.")
else:
    LOGGER.info(f"GPU: {torch.cuda.get_device_name(0)} is available.")
LOGGER.info(f'Device is {DEVICE}')

# Exploring augmentations

In [None]:


def prepare_image(img):
    """Return *img* ready for display as uint8.

    uint8 input is returned unchanged (same object); anything else is clipped
    to [0, 255] and cast to uint8.
    """
    if img.dtype == np.uint8:
        return img
    return np.clip(img, 0, 255).astype(np.uint8)

def prepare_label(lbl):
    """Binarise *lbl* for display: positive entries become 255, all others 0 (uint8)."""
    return np.where(lbl > 0, 255, 0).astype(np.uint8)


def plot_augmentations(original_image, augmented_image, original_label, augmented_label):
    """Show a 2x2 grayscale grid: images on the top row, labels on the bottom row.

    Left column is the original, right column the augmented version.
    """
    # Imported here because the notebook's module-level matplotlib import sits in
    # a later cell, so `plt` is not yet defined when this cell first runs.
    import matplotlib.pyplot as plt

    panels = (
        (original_image, 'Original Image'),
        (augmented_image, 'Augmented Image'),
        (original_label, 'Original Label'),
        (augmented_label, 'Augmented Label'),
    )
    fig, axs = plt.subplots(2, 2, figsize=(12, 8))
    # axs.flat walks the grid row by row: [0,0], [0,1], [1,0], [1,1].
    for ax, (data, title) in zip(axs.flat, panels):
        ax.imshow(data, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    

# Example augmentation pipeline, applied once to the first (image, label) pair
# below so its effect can be inspected visually. The transform list matches
# TRAIN_AUGMENTATION / VALIDATION_AUGMENTATION defined in a later cell.
AUGMENTATION_SAMPLE = A.Compose([
    # Spatial transforms: passing the label as `mask=` applies the same
    # geometric change to it, keeping image and label aligned.
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Rotate(limit=15, p=0.5, border_mode=0),  # border_mode=0 is cv2.BORDER_CONSTANT
    A.ShiftScaleRotate(
        shift_limit=0.0625, scale_limit=0.1, rotate_limit=0,
        p=0.5, border_mode=0
    ),
    # With probability 0.5, apply exactly one photometric change; these
    # pixel-level transforms affect the image only, not the mask.
    A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
        A.GaussianBlur(blur_limit=(3, 7)),
        # NOTE(review): `var_limit` is deprecated/renamed in newer albumentations
        # releases — confirm against the installed version.
        A.GaussNoise(var_limit=(10.0, 50.0)),
    ], p=0.5),
],
# NOTE(review): 'mask' is already a built-in target of A.Compose, so this
# mapping looks redundant — confirm before removing.
additional_targets={'mask': 'mask'})



# Grayscale arrays of the first image/label pair (load_image converts to mode 'L').
image = np.array(load_image(images_path[0]))
label = np.array(load_image(labels_path[0]))



# One random draw of the pipeline, applied jointly to image and mask.
augmented = AUGMENTATION_SAMPLE(image=image, mask=label)
augmented_image = augmented['image']
augmented_label = augmented['mask']


# Convert to displayable uint8: images clipped, labels binarised to {0, 255}.
image_plot = prepare_image(image)
augmented_image_plot = prepare_image(augmented_image)
label_plot = prepare_label(label)
augmented_label_plot = prepare_label(augmented_label)    

plot_augmentations(image_plot, augmented_image_plot, label_plot, augmented_label_plot)

# Train

In [None]:
import string

def generate_random_string(length=8, use_digits=True, use_lowercase=True, use_uppercase=True, use_special=False):
    """Return *length* characters drawn with ``random.choice`` from the enabled character classes.

    The pool is built in the order digits, lowercase, uppercase, punctuation.
    This uses the global (non-cryptographic) ``random`` state, so it is only
    suitable for things like filename suffixes, not secrets.

    Raises:
        ValueError: if every character class is disabled.
    """
    enabled_classes = (
        (use_digits, string.digits),
        (use_lowercase, string.ascii_lowercase),
        (use_uppercase, string.ascii_uppercase),
        (use_special, string.punctuation),
    )
    char_pool = "".join(chars for flag, chars in enabled_classes if flag)

    if not char_pool:
        raise ValueError("At least one character type must be enabled.")

    picked = []
    for _ in range(length):
        picked.append(random.choice(char_pool))
    return "".join(picked)

In [None]:


def dice_score(preds, targets, epsilon=1e-6, threshold=0.5):
    """Batch-mean Dice coefficient between thresholded predictions and targets.

    Args:
        preds: raw model outputs (logits); sigmoid followed by ``threshold``
            turns them into a binary prediction.
        targets: binary ground-truth tensor with the same shape as ``preds``.
        epsilon: smoothing term; a sample whose prediction and target are both
            empty scores 1.0.
        threshold: probability cut-off applied after the sigmoid.

    Dice is computed per sample, summing over every dimension except the batch
    dimension 0, and then averaged over the batch. Inputs shaped (N, C, H, W),
    (N, H, W), etc. are all accepted.

    Returns:
        The batch-mean Dice as a Python float.

    Raises:
        ValueError: if ``preds`` has fewer than 2 dimensions (no per-sample axes).
    """
    if preds.dim() < 2:
        raise ValueError(
            f"dice_score expects a batch dim plus at least one more, got shape {tuple(preds.shape)}"
        )
    binary = (torch.sigmoid(preds) > threshold).float()
    # Reduce over every non-batch dimension.
    sample_dims = tuple(range(1, binary.dim()))
    intersection = (binary * targets).sum(dim=sample_dims)
    totals = binary.sum(dim=sample_dims) + targets.sum(dim=sample_dims)
    dice = (2.0 * intersection + epsilon) / (totals + epsilon)
    return dice.mean().item()

In [None]:
import torch
from typing import Tuple
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import copy
import gc

@timeit
def train_epoch(model: torch.nn.Module, train_loader: DataLoader, optimizer, criterion) -> Tuple[float, float]:
    """Run one optimisation pass over *train_loader*, updating *model* in place.

    Returns (mean loss, mean Dice) over the dataset. Each batch contributes in
    proportion to its size. Dice is measured on the outputs computed before
    that batch's optimizer step.
    """
    model.train()
    running_loss, running_dice = 0.0, 0.0

    for batch_inputs, batch_masks in tqdm(train_loader, desc="Training"):
        batch_inputs = batch_inputs.to(DEVICE)
        batch_masks = batch_masks.to(DEVICE)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_masks)
        batch_loss.backward()
        optimizer.step()

        weight = batch_inputs.size(0)
        running_loss += batch_loss.item() * weight
        running_dice += dice_score(logits, batch_masks) * weight

        # Drop per-batch references so their tensors can be freed before the next batch.
        del batch_inputs, batch_masks, logits, batch_loss

    n_samples = len(train_loader.dataset)
    return running_loss / n_samples, running_dice / n_samples

@timeit
def val_epoch(model: torch.nn.Module, val_loader: DataLoader, criterion) -> Tuple[float, float]:
    """Evaluate *model* on *val_loader* in eval mode without tracking gradients.

    Returns (mean loss, mean Dice) over the dataset, each batch weighted by its size.
    """
    model.eval()
    loss_sum, dice_sum = 0.0, 0.0

    with torch.no_grad():
        for batch_x, batch_y in tqdm(val_loader, desc="Validation"):
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)

            logits = model(batch_x)
            n_in_batch = batch_x.size(0)
            loss_sum += criterion(logits, batch_y).item() * n_in_batch
            dice_sum += dice_score(logits, batch_y) * n_in_batch

            del batch_x, batch_y, logits

    n_samples = len(val_loader.dataset)
    return loss_sum / n_samples, dice_sum / n_samples

@timeit
def test_model(model: torch.nn.Module, test_loader: DataLoader, criterion) -> Tuple[float, float]:
    """Score *model* on the held-out *test_loader* (eval mode, no gradients).

    Returns (mean loss, mean Dice) over the dataset, each batch weighted by its size.
    """
    model.eval()
    weighted_loss = 0.0
    weighted_dice = 0.0

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Testing"):
            images = images.to(DEVICE)
            targets = targets.to(DEVICE)

            logits = model(images)
            n_in_batch = images.size(0)
            weighted_loss += criterion(logits, targets).item() * n_in_batch
            weighted_dice += dice_score(logits, targets) * n_in_batch

            del images, targets, logits

    n_total = len(test_loader.dataset)
    return weighted_loss / n_total, weighted_dice / n_total


In [None]:

def _state_dict_snapshot(model: torch.nn.Module) -> dict:
    """Return a detached CPU copy of *model*'s state_dict.

    Cheaper than ``copy.deepcopy(model)``: no module objects are duplicated and
    the copies do not occupy accelerator memory.
    """
    return {name: tensor.detach().to('cpu', copy=True) for name, tensor in model.state_dict().items()}


def _test_checkpoint(model, checkpoint_path, test_loader, criterion):
    """Evaluate a copy of *model* on *test_loader*, first loading *checkpoint_path* if it is non-empty."""
    candidate = copy.deepcopy(model)
    if checkpoint_path:
        candidate.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE, weights_only=True))
    candidate = candidate.to(DEVICE)
    return test_model(candidate, test_loader, criterion)


@timeit
def train(
    model: torch.nn.Module,
    train_dataset: torch.utils.data.Dataset,
    val_dataset: torch.utils.data.Dataset,
    test_dataset: torch.utils.data.Dataset,
    stable_epochs_count: int,
    stable_dice_score_distance: float,
    best_models_dir: str,
    stable_models_dir: str,
    epochs: int = 100,
    stable_start_epoch: int = 50,
) -> Tuple[list, list, list, list, str, str, float, float, float, float]:
    """Train *model* with Adam + BCEWithLogitsLoss and checkpoint the best and "stable" weights.

    * Best checkpoint: saved to *best_models_dir* every time validation Dice improves.
    * Stable checkpoint: from 0-based epoch *stable_start_epoch* on, when the last
      *stable_epochs_count* validation Dice values span at most
      *stable_dice_score_distance*, the best snapshot in that window is saved to
      *stable_models_dir*. The returned stable path is the one with the highest
      Dice among all stable saves.

    Returns:
        (train_losses, train_dice_scores, val_losses, val_dice_scores,
         best_model_path, best_stable_model_path,
         best_model_test_loss, best_model_test_dice_score,
         stable_model_test_loss, stable_model_test_dice_score).
        The stable test values are None when no stable checkpoint was saved. If
        no best checkpoint was saved, the final weights are tested instead.
    """
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

    # Fully qualified so this cell does not depend on the later
    # `import torch.nn as nn` cell having run.
    criterion = torch.nn.BCEWithLogitsLoss()

    model = model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    train_losses, val_losses = [], []
    train_dice_scores, val_dice_scores = [], []

    best_val_dice_score = 0
    best_stable_dice_score = 0
    best_model_path = ""
    best_stable_model_path = ""

    # Rolling window of CPU state_dict snapshots, aligned with the last
    # `stable_epochs_count` entries of val_dice_scores.
    recent_snapshots = []

    for epoch in range(epochs):
        epoch_train_loss, train_dice = train_epoch(model, train_loader, optimizer, criterion)
        train_losses.append(epoch_train_loss)
        train_dice_scores.append(train_dice)

        epoch_val_loss, val_dice = val_epoch(model, val_loader, criterion)
        val_losses.append(epoch_val_loss)
        val_dice_scores.append(val_dice)

        recent_snapshots.append(_state_dict_snapshot(model))
        if len(recent_snapshots) > stable_epochs_count:
            recent_snapshots.pop(0)

        if val_dice > best_val_dice_score:
            best_val_dice_score = val_dice
            model_path = os.path.join(
                best_models_dir,
                f'dice_{round(val_dice, 4)}_epoch_{epoch+1}_{generate_random_string(3)}.pth'
            )
            torch.save(model.state_dict(), model_path)
            best_model_path = model_path
            LOGGER.info(f'Model saved: {model_path}')

        if (stable_epochs_count > 0 and epoch >= stable_start_epoch
                and len(val_dice_scores) >= stable_epochs_count):
            recent_dice = val_dice_scores[-stable_epochs_count:]
            if max(recent_dice) - min(recent_dice) <= stable_dice_score_distance:
                best_idx = recent_dice.index(max(recent_dice))
                stable_dice = recent_dice[best_idx]
                # 1-based epoch at which the chosen snapshot was taken (not the current epoch).
                snapshot_epoch = epoch - len(recent_dice) + 1 + best_idx + 1
                stable_model_path = os.path.join(
                    stable_models_dir,
                    f'stable_dice_{round(stable_dice, 4)}_epoch_{snapshot_epoch}_{generate_random_string(3)}.pth'
                )
                torch.save(recent_snapshots[best_idx], stable_model_path)
                LOGGER.info(f'Stable model saved: {stable_model_path}')
                # Only remember the path when it beats every earlier stable save, so
                # the returned (and tested) "best stable" checkpoint really is the best.
                if not best_stable_model_path or stable_dice > best_stable_dice_score:
                    best_stable_dice_score = stable_dice
                    best_stable_model_path = stable_model_path

        LOGGER.info(
            f'Epoch {epoch+1}/{epochs} | '
            f'Train Loss: {epoch_train_loss:.4f} | Train Dice: {train_dice:.4f} | '
            f'Val Loss: {epoch_val_loss:.4f} | Val Dice: {val_dice:.4f}'
        )

        torch.cuda.empty_cache()
        gc.collect()

    best_model_test_loss, best_model_test_dice_score = _test_checkpoint(
        model, best_model_path, test_loader, criterion
    )

    if best_stable_model_path:
        stable_model_test_loss, stable_model_test_dice_score = _test_checkpoint(
            model, best_stable_model_path, test_loader, criterion
        )
    else:
        stable_model_test_loss, stable_model_test_dice_score = None, None

    return (
        train_losses,
        train_dice_scores,
        val_losses,
        val_dice_scores,
        best_model_path,
        best_stable_model_path,
        best_model_test_loss,
        best_model_test_dice_score,
        stable_model_test_loss,
        stable_model_test_dice_score
    )

# Plot utils

In [None]:
import matplotlib.pyplot as plt

def plot_metrics(train_metric, val_metric, save_path, metric_name="Loss", color_train='b', color_val='r'):
    """Plot per-epoch training vs. validation curves, save the figure to *save_path*, then show it.

    Epochs are numbered from 1. *metric_name* is used for the legend, the
    y-axis label and the title.
    """
    x_axis = range(1, len(train_metric) + 1)

    plt.figure()
    curves = ((train_metric, color_train, 'Training'), (val_metric, color_val, 'Validation'))
    for values, colour, split in curves:
        plt.plot(x_axis, values, color=colour, label=f'{split} {metric_name}')

    plt.xlabel('Epochs')
    plt.ylabel(metric_name)
    plt.title(f'Training and Validation {metric_name} Over Epochs')
    plt.legend()
    plt.grid(True)

    plt.savefig(save_path)
    plt.show()


# Transformations

In [None]:
from albumentations.pytorch import ToTensorV2

def _normalize_and_tensorize():
    """Return a fresh pipeline that normalises a 1-channel image and converts it to a tensor.

    Normalize uses mean 0.485 / std 0.229 on pixel values scaled by 1/255, then
    ToTensorV2 produces torch tensors (the mask is passed through as well).
    """
    return A.Compose([
        A.Normalize(mean=(0.485,), std=(0.229,), max_pixel_value=255.0),
        ToTensorV2(),
    ])


# Three independent Compose objects with identical, deterministic settings.
TRAIN_TRANSFORMATION = _normalize_and_tensorize()
VALIDATION_TRANSFORMATION = _normalize_and_tensorize()
TEST_TRANSFORMATION = _normalize_and_tensorize()

# Augmentation

In [None]:
def _build_augmentation():
    """Return a fresh random augmentation pipeline applied jointly to image and mask.

    Spatial part: flips, 90-degree rotations, small rotations and shift/scale
    (constant zero padding). Photometric part: with probability 0.5, exactly one
    of brightness/contrast, Gaussian blur or Gaussian noise.
    """
    return A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Rotate(limit=15, p=0.5, border_mode=0),
            A.ShiftScaleRotate(
                shift_limit=0.0625, scale_limit=0.1, rotate_limit=0,
                p=0.5, border_mode=0,
            ),
            A.OneOf(
                [
                    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
                    A.GaussianBlur(blur_limit=(3, 7)),
                    # NOTE(review): `var_limit` is deprecated/renamed in newer
                    # albumentations releases — confirm the installed version.
                    A.GaussNoise(var_limit=(10.0, 50.0)),
                ],
                p=0.5,
            ),
        ],
        additional_targets={'mask': 'mask'},
    )


# Two independent Compose objects with identical settings.
# NOTE(review): randomly augmenting the validation split makes validation
# scores depend on augmentation draws — confirm this is intended.
TRAIN_AUGMENTATION = _build_augmentation()
VALIDATION_AUGMENTATION = _build_augmentation()

# Set up datasets

In [None]:
# Split boundaries into the (already shuffled) path lists:
# [0, 330) train, [330, 360) validation, [360, end) test.
TRAIN_END, VAL_END = 330, 360

train_dataset = TeethDataset()
train_dataset.load(zip(images_path[:TRAIN_END], labels_path[:TRAIN_END]))
train_dataset.augment(15, TRAIN_AUGMENTATION)  # 15 extra random variants per training pair
train_dataset.transform(TRAIN_TRANSFORMATION)

val_dataset = TeethDataset()
val_dataset.load(zip(images_path[TRAIN_END:VAL_END], labels_path[TRAIN_END:VAL_END]))
val_dataset.augment(5, VALIDATION_AUGMENTATION)  # 5 extra random variants per validation pair
val_dataset.transform(VALIDATION_TRANSFORMATION)

# The test split is never augmented.
test_dataset = TeethDataset()
test_dataset.load(zip(images_path[VAL_END:], labels_path[VAL_END:]))
test_dataset.transform(TEST_TRANSFORMATION)

# Models

In [None]:
def get_model_size(model):
    """Return the memory taken by *model*'s parameters plus buffers, in MiB (1024**2 bytes)."""
    total_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    total_bytes += sum(b.nelement() * b.element_size() for b in model.buffers())
    return total_bytes / (1024 ** 2)

In [None]:
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two (3x3 Conv2d -> BatchNorm2d -> ReLU) stages; padding=1 preserves height and width.

    The layers live in ``self.double_conv`` at indices 0-5, which fixes the
    state_dict key names that saved checkpoints rely on.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for conv_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class HeavyUnet(nn.Module):
    """Larger U-Net variant (64 -> 1024 channels).

    In this notebook it is only instantiated to report its size with
    get_model_size; the lightweight ``Model`` class is the one that is trained.

    Input: a 1-channel image batch. Output: 1 channel of raw scores (forward
    applies no activation). The encoder halves the resolution four times with
    MaxPool2d(2). Each decoder level doubles it with a stride-2 ConvTranspose2d
    and concatenates the encoder feature map of the same level (skip connection).
    NOTE(review): torch.cat on the skips needs matching spatial sizes, which
    presumably means H and W must be divisible by 16 — confirm for your inputs.
    """
    def __init__(self):
        super(HeavyUnet, self).__init__()
        # Encoder: 1 -> 64 -> 128 -> 256 -> 512 -> 1024 channels.
        self.inc = DoubleConv(1, 64)     
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 128)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 256)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 512)
        )
        self.down4 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(512, 1024)
        )

        # Decoder: each up-conv halves the channels; concatenating the skip
        # doubles them again, hence DoubleConv(2*C, C) after every up-conv.
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv_up1 = DoubleConv(1024, 512)

        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_up2 = DoubleConv(512, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_up3 = DoubleConv(256, 128)

        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_up4 = DoubleConv(128, 64)

        # Output layer: 1x1 conv to a single channel of raw scores.
        self.outc = nn.Conv2d(64, 1, kernel_size=1)

    #######DO NOT CHANGE THIS PART########
    def init(self, path="model.pth"):
        self.load_state_dict(torch.load(path, weights_only=True))
    ######################################
    def save(self, path: str):
        """Write a copy of this model's state_dict to *path* with torch.save."""
        torch.save(copy.deepcopy(self.state_dict()), path)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-pixel raw scores (1 channel) for the input batch *x*."""

        # Encoder; x1..x4 are kept for the skip connections.
        x1 = self.inc(x)       
        x2 = self.down1(x1)   
        x3 = self.down2(x2)    
        x4 = self.down3(x3)    
        x5 = self.down4(x4) 

        # Decoder
        x = self.up1(x5)      
        x = torch.cat([x, x4], dim=1) 
        x = self.conv_up1(x)  

        x = self.up2(x)    
        x = torch.cat([x, x3], dim=1)  
        x = self.conv_up2(x)   

        x = self.up3(x)        
        x = torch.cat([x, x2], dim=1)  
        x = self.conv_up3(x) 

        x = self.up4(x)        
        x = torch.cat([x, x1], dim=1) 
        x = self.conv_up4(x)   

        mask = self.outc(x)  
        return mask

# Cell output: parameter + buffer size of HeavyUnet in MiB.
get_model_size(HeavyUnet())

In [None]:
# Lightweight U-Net (32 -> 512 channels): the model trained in this notebook,
# exported to model.pth and packaged into the submission.
class Model(nn.Module):
    """Lightweight U-Net for single-channel segmentation.

    Input: a 1-channel image batch. Output: 1 channel of raw scores (forward
    applies no activation; training pairs it with BCEWithLogitsLoss). The
    encoder halves the resolution four times with MaxPool2d(2). Each decoder
    level doubles it with a stride-2 ConvTranspose2d and concatenates the
    matching encoder features.
    NOTE(review): torch.cat on the skips needs matching spatial sizes, which
    presumably means H and W must be divisible by 16 — confirm for your inputs.
    Attribute names define the state_dict keys that ``init`` loads from model.pth.
    """
    def __init__(self):
        super(Model, self).__init__()
        # Encoder: 1 -> 32 -> 64 -> 128 -> 256 -> 512 channels.
        self.inc = DoubleConv(1, 32)         
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(32, 64)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 128)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 256)
        )
        self.down4 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 512)
        )

        # Decoder: each up-conv halves the channels; concatenating the skip
        # doubles them again, hence DoubleConv(2*C, C) after every up-conv.
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_up1 = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_up2 = DoubleConv(256, 128)

        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_up3 = DoubleConv(128, 64)

        self.up4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv_up4 = DoubleConv(64, 32)

        # 1x1 conv to a single channel of raw scores.
        self.outc = nn.Conv2d(32, 1, kernel_size=1)

    #######DO NOT CHANGE THIS PART########
    def init(self, path="model.pth"):
        self.load_state_dict(torch.load(path, weights_only=True))
    ######################################
    def save(self, path: str):
        """Write a copy of this model's state_dict to *path* with torch.save."""
        torch.save(copy.deepcopy(self.state_dict()), path)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-pixel raw scores (1 channel) for the input batch *x*."""
        # Encoder; x1..x4 are kept for the skip connections.
        x1 = self.inc(x)     
        x2 = self.down1(x1)    
        x3 = self.down2(x2)    
        x4 = self.down3(x3)    
        x5 = self.down4(x4)   

        # Decoder: upsample, concatenate the matching encoder features, refine.
        x = self.up1(x5)       
        x = torch.cat([x, x4], dim=1) 
        x = self.conv_up1(x)

        x = self.up2(x)       
        x = torch.cat([x, x3], dim=1) 
        x = self.conv_up2(x) 

        x = self.up3(x)        
        x = torch.cat([x, x2], dim=1)  
        x = self.conv_up3(x)  

        x = self.up4(x)        
        x = torch.cat([x, x1], dim=1)  
        x = self.conv_up4(x)  

        mask = self.outc(x)   
        return mask

# Cell output: parameter + buffer size of Model in MiB.
get_model_size(Model())

# Train model

In [None]:
import shutil

BEST_MODELS_DIR = "./best_models"
STABLE_MODELS_DIR = "./stable_models"
PLOTS_DIR = "./plots"

# Start every run from empty output directories. This is a portable replacement
# for `! rm -rf ...` and is driven by the constants above instead of repeated literals.
for _output_dir in (BEST_MODELS_DIR, STABLE_MODELS_DIR, PLOTS_DIR):
    shutil.rmtree(_output_dir, ignore_errors=True)
    os.makedirs(_output_dir, exist_ok=True)

In [None]:

model = Model()

(train_losses,
train_dice_scores,
val_losses,
val_dice_scores,
best_model_path,
best_stable_model_path,
best_model_test_loss,
best_model_test_dice_score,
stable_model_test_loss,
stable_model_test_dice_score) = train(model = model, epochs = 100,
                                      train_dataset = train_dataset, val_dataset = val_dataset, test_dataset = test_dataset,
                                      stable_epochs_count = 5,
                                      stable_dice_score_distance = 0.01,
                                      best_models_dir = BEST_MODELS_DIR,
                                      stable_models_dir = STABLE_MODELS_DIR)

LOGGER.info(
    f"Best checkpoint [{best_model_path}]: "
    f"test loss {best_model_test_loss:.4f}, test Dice {best_model_test_dice_score:.4f}"
)
if best_stable_model_path:
    LOGGER.info(
        f"Best stable checkpoint [{best_stable_model_path}]: "
        f"test loss {stable_model_test_loss:.4f}, test Dice {stable_model_test_dice_score:.4f}"
    )

# plot_metrics' third positional parameter is the save path, so pass an explicit
# file under PLOTS_DIR and name the metric; otherwise the Dice plot is labelled "Loss".
plot_metrics(train_losses, val_losses, os.path.join(PLOTS_DIR, 'loss.png'), metric_name='Loss')
plot_metrics(train_dice_scores, val_dice_scores, os.path.join(PLOTS_DIR, 'dice.png'), metric_name='Dice')



# Save Model

In [None]:
# Export the best-validation checkpoint of the run above. Use the path returned
# by train() rather than a hard-coded filename: checkpoint names carry a random
# suffix, and ./best_models is wiped at the start of every run.
# NOTE(review): use best_stable_model_path instead to export the stable checkpoint.
if not best_model_path:
    raise FileNotFoundError("No best-model checkpoint was saved during training")
# map_location='cpu' lets the export run without the GPU used for training.
stable_conf = torch.load(best_model_path, map_location='cpu', weights_only=True)
model = Model()
model.load_state_dict(stable_conf)
model.save('model.pth')

# Submit

In [None]:
import zipfile

# NOTE(review): model.py is not written anywhere in this notebook; it presumably
# must contain the Model class and be created separately before this cell runs.
SUBMISSION_FILES = ('model.pth', 'model.py')

# Check up front so a missing file fails with a clear message instead of
# leaving a half-written submission.zip behind.
_missing = [name for name in SUBMISSION_FILES if not os.path.isfile(name)]
if _missing:
    raise FileNotFoundError(f"Cannot build submission.zip, missing: {', '.join(_missing)}")

with zipfile.ZipFile('submission.zip', 'w') as zipf:
    for name in SUBMISSION_FILES:
        zipf.write(name)