# Connection
*Kaggle do not support user input, that's why I cannot use Rclone to connect to my OneDrive storage for code and data*

*-> My solution: Save data on Kaggle input directory, code sync by Github (git clone) and save checkpoint on Kaggle working directory*

In [1]:
import os


os.chdir("/kaggle/working")

In [2]:
!cd ./ && pwd && ls

/kaggle/working
__notebook_source__.ipynb


In [3]:
# from kaggle_secrets import UserSecretsClient


# user_secrets = UserSecretsClient()
# github_token = user_secrets.get_secret("github_token")

# os.system(f"git clone -b develop https://{github_token}@github.com/Kokoroou/self-supervised-segmentation")


# Preparation

## Data

### Main config to prepare data

In [None]:
# Config
data_source = "kokoroou/polypgen2021"
data_dir = "./input"

### Preprocessing data

In [None]:
%pip install --upgrade pip
%pip install kaggle
%pip install torchvision

In [None]:
from pathlib import Path, PurePosixPath
import time

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder


# Download data from Kaggle (Only need on Colab)
# !kaggle datasets download -d {data_source} -p {data_dir}

# Define the path to the directory containing the train images
current_dir = Path(os.getcwd())
source_dir = Path(data_dir) / "polypgen2021" / "PolypGen2021_MultiCenterData_v3"

# Define the path to the file containing the train image names
train_filepath = source_dir / "train_autoencoder.txt"
test_filepath = source_dir / "test_autoencoder.txt"

# Open the file with names of training, testing image file, then make DataLoader
with open(train_filepath, "r") as f:
    train_filenames = f.read().splitlines()
with open(test_filepath, "r") as f:
    test_filenames = f.read().splitlines()

# Create a custom dataset class to load the images
class CustomDataset(ImageFolder):
    def __init__(self, root, names, transform=None):
        super().__init__(root, transform=transform)
        self.samples = [
            (Path(root, PurePosixPath(name)), 0) for name in names
        ]

# Define the transformations to apply to the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

start = time.time()
print("Create dataset...")
# Create an instance of the custom dataset
train_dataset = CustomDataset(str(source_dir), train_filenames, transform=transform)
test_dataset = CustomDataset(str(source_dir), test_filenames, transform=transform)

# Create a data loader to load the images in batches
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
print(f"Create dataset. Finish in {time.time() - start}")


## Code

### Main config to prepare code

In [None]:
# Config
repo_url = "https://github.com/Kokoroou/self-supervised-segmentation"
repo_dir = "./src"


### Prepare code

In [None]:
# Download repo from Github
!git clone -b develop {repo_url} {repo_dir}

# Install requirements
%pip install -r "./src/self-supervised-segmentation/requirements.txt"

## Output

### Main config to save result

In [None]:
# Config
checkpoint_dir = "./checkpoint"

### Connect to output storage


In [None]:
# Default storage in Kaggle


# Running

### Training

In [None]:
# Use code on Github to train


In [4]:
from math import sqrt

import torch


def unpatchify(patches, patch_size=16):
    """
    Combining patches into images.

    Args:
        patches: Input tensor with size (
            batch,
            (height / patch_size) * (width / patch_size),
            channels * patch_size * patch_size
            )
        patch_size: Patch size
    Returns:
        A batch of images with size (batch, channels, height, width)
    """
    batch_size, num_patches, total_patch_size = patches.shape
    channels = total_patch_size // patch_size ** 2

    # Count number of patches in height and width
    height_count = width_count = int(sqrt(num_patches))

    # Calculate height and width of the image
    height = width = height_count * patch_size

    # Raise error if num_patches is not a square number
    assert height_count * width_count == num_patches

    # Unpatching patches into images
    patches = patches.reshape((batch_size, height_count, width_count, patch_size, patch_size, channels))
    patches = torch.einsum('nhwpqc->nchpwq', patches)
    images = patches.reshape((batch_size, channels, height, width))

    return images


def random_masking(x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.

    Args:
        x: torch.Tensor, shape [N, L, D]
           The input sequence with embedded vector of elements of samples in batch.
           N: batch size, L: embedded vector length of a sample (number of elements), D: encoder embedding dimension

        mask_ratio: float
           The ratio of elements to be masked (removed) from each sample.

    Returns:
        kept_patches: torch.Tensor, shape [N, L_masked, D]
           The masked sequence with a subset of elements removed for each sample.
           L_masked: masked sequence length after removing elements.

        mask: torch.Tensor, shape [N, L]
           Binary mask indicating which elements are kept (0) or removed (1) from each sequence.

        ids_restore: torch.Tensor, shape [N, L]
           Indices used to restore the original order of the elements after shuffling.

    """
    batch_size, num_patches, dimension = x.shape
    keep_count = int(num_patches * (1 - mask_ratio))

    noise = torch.rand(batch_size, num_patches, device=x.device)  # noise in [0, 1]

    # Sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # Keep the first subset
    ids_keep = ids_shuffle[:, :keep_count]
    kept_patches = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dimension))

    # Generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([batch_size, num_patches], device=x.device)
    mask[:, :keep_count] = 0
    # Un-shuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return kept_patches, mask, ids_restore


In [5]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out) # (M, D/2)
    emb_cos = np.cos(out) # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed


In [6]:
from functools import partial
from typing import Any

import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block
from torchinfo import summary

class MaskedAutoencoderViT(nn.Module):
    """
    Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 encoder_embed_dim: int = 1024, encoder_depth: int = 24, encoder_num_heads: int = 16,
                 decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16,
                 mlp_ratio: float = 4.,
                 norm_layer: Any = nn.LayerNorm,
                 norm_pix_loss: bool = False):
        """
        Initialize model structure

        Args:
            img_size: Image size of input image (e.g. 224 for 224x224 image)
            patch_size: Size of each patch
            in_chans: Number of input channels (e.g. 3 for RGB)
            encoder_embed_dim: Embedding dimension of encoder
            encoder_depth: Number of encoder blocks
            encoder_num_heads: Number of heads in encoder
            decoder_embed_dim: Embedding dimension of decoder
            decoder_depth: Number of decoder blocks
            decoder_num_heads: Number of heads in decoder
            mlp_ratio: Ratio of MLP hidden dim to embedding dim
            norm_layer: Normalization layer
            norm_pix_loss: Whether to normalize pixel loss
        """
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.norm_pix_loss = norm_pix_loss

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, encoder_embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_embed_dim))
        # fixed sin-cos embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, encoder_embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(encoder_depth)])
        self.norm = norm_layer(encoder_embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # fixed sin-cos embedding
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_encoder(self, x, mask_ratio):
        """
        Forward pass through the encoder of the neural network model.

        Args:
            x (torch.Tensor): Input image tensor of shape (N, C, H, W).
            mask_ratio (float): Ratio of elements to mask during random masking.

        Returns:
            torch.Tensor: Encoded output tensor.
            torch.Tensor: Mask tensor indicating the masked elements.
            torch.Tensor: Restored indices tensor for masked elements.
        """
        # Divide image into patches and embed them
        x = self.patch_embed(x)

        # Add positional embedding without classification token
        x = x + self.pos_embed[:, 1:, :]

        # Masking image patches, only keep patches that unmasked and info for restoring
        x, mask, ids_restore = random_masking(x, mask_ratio)

        # Append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        x = unpatchify(x, self.patch_size)

        return x

    # def forward_loss(self, imgs, pred, mask):
    #     """
    #     imgs: [N, 3, H, W]
    #     pred: [N, L, p*p*3]
    #     mask: [N, L], 0 is keep, 1 is remove,
    #     """
    #     target = patchify(imgs, self.patch_size)
    #     if self.norm_pix_loss:
    #         mean = target.mean(dim=-1, keepdim=True)
    #         var = target.var(dim=-1, keepdim=True)
    #         target = (target - mean) / (var + 1.e-6)**.5
    #
    #     loss = (pred - target) ** 2
    #     loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
    #
    #     loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
    #     return loss

    def forward(self, imgs, mask_ratio=0.75):
        """
        Args:
            imgs: Batch of images (shape: [N, C, H, W])
            mask_ratio: Ratio of masked patches

        Returns:
            loss: Masked autoencoder loss
            pred: Predicted patches
            mask: Mask of removed patches
        """
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        # loss = self.forward_loss(imgs, pred, mask)
        return pred

In [7]:
# import gc
# from tqdm import tqdm

# for i in tqdm(range(10)):
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     image_size = 224
#     model = MaskedAutoencoderViT(img_size=image_size).to(device)

#     with torch.no_grad():
#         result = summary(model, (64, 3, image_size, image_size))
#     #     print(result)
#     #     print("\nNumber of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

#         # Can only free allocated GPU memory at the moment, but cannot free RAM
#         del result
#         gc.collect()
#         torch.cuda.empty_cache()
#         print(torch.cuda.memory_summary())


In [8]:
import argparse


parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument("-b", "--batch_size", type=int, default=64, help="Batch size")
parser.add_argument("-e", "--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("-lr", "--learning_rate", type=int, default=0.001, help="Learning rate")  
parser.add_argument("-o", "--output_dir", type=str, default="./output_dir", help="Directory for save checkpoint")

args = parser.parse_known_args()[0]

In [9]:
import wandb
from kaggle_secrets import UserSecretsClient


user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("wandb_api_key")
wandb.login(key=wandb_api_key)

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [10]:
# I use wandb to log the hyper-parameters, training results, and model weights.
import random
from datetime import datetime
from pathlib import Path

import wandb


current_dir = Path(os.getcwd())

# Start a new wandb run to track this script
wandb.init(
    job_type="train",
    dir=current_dir,
    config=args,
    project="semantic-segmentation",
    name="mae_" + datetime.now().strftime("%Y%m%d_%H%M%S"),
    notes="Masked Autoencoder",
    mode="online"
)


[34m[1mwandb[0m: Currently logged in as: [33mkokoroou[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
from pathlib import Path
import time

import torch


def save_model(args, epoch, model, optimizer, criterion):
    output_dir = Path(args.output_dir)
    checkpoint_path = output_dir / f'checkpoint-{epoch}.pth'

    to_save = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'criterion': criterion.state_dict(),
        'epoch': epoch,
        'args': args,
    }
    
    torch.save(to_save, checkpoint_path)


def save_best_model(args, epoch, model, optimizer, criterion):
    output_dir = Path(args.output_dir)
    checkpoint_path = output_dir / f'checkpoint-{epoch}.pth'
    best_checkpoint_path = output_dir / 'checkpoint-best.pth'

    if not checkpoint_path.exists():
        to_save = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'criterion': criterion.state_dict(),
            'epoch': epoch,
            'args': args,
        }

        torch.save(to_save, checkpoint_path)
    else:
        os.rename(checkpoint_path, best_checkpoint_path)


def delete_checkpoint(args, epoch):
    # Delete previous checkpoints, except best and last one
    output_dir = Path(args.output_dir)

    checkpoint_filename = f"checkpoint-{epoch}.pth"
    checkpoint_path = output_dir / checkpoint_filename
    checkpoint_path.unlink(missing_ok=True)
    
    
def save_epoch_result(args, epoch, loss,
                      model, optimizer, criterion):
    global best_loss
    
    start = time.time()
    print(f"Save trained model for epoch {epoch}...")
    save_model(args=args, model=model, optimizer=optimizer,
               criterion=criterion, epoch=epoch)
    print(f"Save trained model for epoch {epoch}. Finish in {time.time() - start}")

    if loss.item() < best_loss:
        best_loss = loss.item()

        with open(Path(current_dir) / "info.txt", "w") as f:
            f.write(f"Best epoch: {epoch}\nLoss: {loss.item()}")
          
        delete_checkpoint(args=args, epoch="best")

        start = time.time()
        print("Save best trained model...")
        save_best_model(
            args=args, model=model, optimizer=optimizer,
            criterion=criterion, epoch=epoch
        )
        print(f"Save best trained model. Finish in {time.time() - start}")

    delete_checkpoint(args=args, epoch=epoch - 1)
    

In [13]:
import argparse
import os
from pathlib import Path, PurePosixPath

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import time
from threading import Thread

device = "cuda" if torch.cuda.is_available() else "cpu"

start = time.time()
print("Initialize and add model to GPU...")
model = MaskedAutoencoderViT()
model= nn.DataParallel(model)  # Use all GPU
model.to(device)
print(f"Initialize and add model to GPU. Finish in {time.time() - start}")

# Define the path to the directory containing the train images
source_dir = Path(current_dir).parent / "input" / "polypgen2021" / "PolypGen2021_MultiCenterData_v3"

# Define the path to the file containing the train image names
train_filepath = source_dir / "train_autoencoder.txt"
test_filepath = source_dir / "test_autoencoder.txt"

# Open the file with names of training, testing image file, then make DataLoader
with open(train_filepath, "r") as f:
    train_filenames = f.read().splitlines()
with open(test_filepath, "r") as f:
    test_filenames = f.read().splitlines()

# Create a custom dataset class to load the images
class CustomDataset(ImageFolder):
    def __init__(self, root, names, transform=None):
        super().__init__(root, transform=transform)
        self.samples = [
            (Path(root, PurePosixPath(name)), 0) for name in names
        ]

# Define the transformations to apply to the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

start = time.time()
print("Create dataset...")
# Create an instance of the custom dataset
train_dataset = CustomDataset(str(source_dir), train_filenames, transform=transform)
test_dataset = CustomDataset(str(source_dir), test_filenames, transform=transform)

# Create a data loader to load the images in batches
batch_size = args.batch_size
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
print(f"Create dataset. Finish in {time.time() - start}")

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=wandb.config["learning_rate"])

best_loss = 1.0

os.makedirs(args.output_dir, exist_ok=True)

thread = None

for epoch in range(args.epochs):
    # Iterate over the data loader batches
    for inputs, _ in tqdm(train_loader, desc=f"Epoch {epoch}"):
        inputs = inputs.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, inputs)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
    
    # log metrics to wandb
    wandb.log({"loss": loss.item()})

    print(f"Epoch {epoch}: loss = {loss.item()}")
    
    if thread != None:
        thread.join()
    
    # create a thread to save checkpoint in background
    thread = Thread(target=save_epoch_result, 
                    args=(args, epoch, loss, model, optimizer, criterion))
    thread.start()

# [Optional] Finish the wandb run, necessary in notebooks
wandb.finish()


Initialize and add model to GPU...
Initialize and add model to GPU. Finish in 14.497633457183838
Create dataset...
Create dataset. Finish in 12.126553297042847


Epoch 0: 100%|██████████| 101/101 [04:13<00:00,  2.51s/it]


Epoch 0: loss = 0.6975340843200684
Save trained model for epoch 0...


Epoch 1:   0%|          | 0/101 [00:00<?, ?it/s]

Save trained model for epoch 0. Finish in 12.108598947525024
Save best trained model...
Save best trained model. Finish in 0.0011594295501708984


Epoch 1: 100%|██████████| 101/101 [04:07<00:00,  2.45s/it]


Epoch 1: loss = 0.7071645259857178
Save trained model for epoch 1...


Epoch 2:   0%|          | 0/101 [00:00<?, ?it/s]

Save trained model for epoch 1. Finish in 13.894358158111572


Epoch 2: 100%|██████████| 101/101 [04:12<00:00,  2.50s/it]


Epoch 2: loss = 0.7043012976646423
Save trained model for epoch 2...


Epoch 3:   0%|          | 0/101 [00:00<?, ?it/s]

Save trained model for epoch 2. Finish in 13.358202457427979


Epoch 3: 100%|██████████| 101/101 [03:57<00:00,  2.35s/it]


Epoch 3: loss = 0.6022506356239319
Save trained model for epoch 3...


Epoch 4:   0%|          | 0/101 [00:00<?, ?it/s]

Save trained model for epoch 3. Finish in 13.787100076675415
Save best trained model...
Save best trained model. Finish in 0.0030465126037597656


Epoch 4: 100%|██████████| 101/101 [03:43<00:00,  2.21s/it]


Epoch 4: loss = 0.6079214811325073
Save trained model for epoch 4...


Epoch 5:   0%|          | 0/101 [00:00<?, ?it/s]

Save trained model for epoch 4. Finish in 14.178593158721924


Epoch 5: 100%|██████████| 101/101 [03:43<00:00,  2.21s/it]


Epoch 5: loss = 0.5678946375846863
Save trained model for epoch 5...


Epoch 6:   0%|          | 0/101 [00:00<?, ?it/s]

Save trained model for epoch 5. Finish in 14.413940906524658
Save best trained model...
Save best trained model. Finish in 0.0029480457305908203


Epoch 6: 100%|██████████| 101/101 [03:38<00:00,  2.16s/it]


Epoch 6: loss = 0.37699443101882935
Save trained model for epoch 6...

Epoch 7:   0%|          | 0/101 [00:00<?, ?it/s]




Epoch 7:   0%|          | 0/101 [00:08<?, ?it/s]

KeyboardInterrupt



In [None]:
# Giảm GPU Memory bị chiếm dụng khi dừng training giữa chừng
import gc

gc.collect()
torch.cuda.empty_cache()

Save trained model for epoch 6. Finish in 13.078415632247925
Save best trained model...
Save best trained model. Finish in 0.0013885498046875
