In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [3]:
""" Full assembly of the parts to form the complete network """
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)

In [4]:
import logging
import numpy as np
import torch
from PIL import Image
from functools import lru_cache
from functools import partial
from itertools import repeat
from multiprocessing import Pool
from os import listdir
from os.path import splitext, isfile, join
from pathlib import Path
from torch.utils.data import Dataset
from tqdm import tqdm


def load_image(filename):
    ext = splitext(filename)[1]
    if ext == '.npy':
        return Image.fromarray(np.load(filename))
    elif ext in ['.pt', '.pth']:
        return Image.fromarray(torch.load(filename).numpy())
    else:
        return Image.open(filename)


def unique_mask_values(idx, mask_dir, mask_suffix):
    mask_file = list(mask_dir.glob(idx + mask_suffix + '.npy'))[0]
    mask = np.asarray(load_image(mask_file))
    if mask.ndim == 2:
        return np.unique(mask)
    elif mask.ndim == 3:
        mask = mask.reshape(-1, mask.shape[-1])
        return np.unique(mask, axis=0)
    else:
        raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}')


class BasicDataset(Dataset):
    def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0, mask_suffix: str = ''):
        self.images_dir = Path(images_dir)
        self.mask_dir = Path(mask_dir)
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
        self.scale = scale
        self.mask_suffix = mask_suffix

        self.ids = [splitext(file)[0] for file in listdir(images_dir)]
        if not self.ids:
            raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')

        logging.info(f'Creating dataset with {len(self.ids)} examples')
        logging.info('Scanning mask files to determine unique values')
        with Pool() as p:
            unique = list(tqdm(
                p.imap(partial(unique_mask_values, mask_dir=self.mask_dir, mask_suffix=self.mask_suffix), self.ids),
                total=len(self.ids)
            ))

        self.mask_values = list(sorted(np.unique(np.concatenate(unique), axis=0).tolist()))
        logging.info(f'Unique mask values: {self.mask_values}')

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def preprocess(mask_values, pil_img, scale, is_mask):
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
        img = np.asarray(pil_img)

        if is_mask:
            mask = np.zeros((newH, newW), dtype=np.int64)
            for i, v in enumerate(mask_values):
                if img.ndim == 2:
                    mask[img == v] = i
                else:
                    mask[(img == v).all(-1)] = i

            return mask

        else:
            if img.ndim == 2:
                img = img[np.newaxis, ...]
            else:
                img = img.transpose((2, 0, 1))

            if (img > 1).any():
                img = img / 255.0

            return img

    def __getitem__(self, idx):
        name = self.ids[idx]
        mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*'))
        img_file = list(self.images_dir.glob(name + '.*'))

        assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
        assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
        mask = load_image(mask_file[0])
        img = load_image(img_file[0])

        assert img.size == mask.size, \
            f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)
        mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)

        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }

In [5]:
import torch
from torch import Tensor


def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

    inter = 2 * (input * target).sum(dim=sum_dim)
    sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

    dice = (inter + epsilon) / (sets_sum + epsilon)
    return dice.mean()


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all classes
    return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    # Dice loss (objective to minimize) between 0 and 1
    fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - fn(input, target, reduce_batch_first=True)

In [6]:
import torch
import torch.nn.functional as F
from tqdm import tqdm


@torch.inference_mode()
def evaluate(net, dataloader, device, amp):
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # iterate over the validation set
    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
            image, mask_true = batch['image'], batch['mask']

            # move images and labels to correct device and type
            image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            mask_true = mask_true.to(device=device, dtype=torch.long)

            # predict the mask
            mask_pred = net(image)

            if net.n_classes == 1:
                assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
            else:
                assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
                # convert to one-hot format
                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                # compute the Dice score, ignoring background
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)

    net.train()
    return dice_score / max(num_val_batches, 1)

In [7]:
import argparse
import logging
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from torch.optim.lr_scheduler import OneCycleLR

import wandb

dir_img = Path('/teamspace/studios/this_studio/Ragnor/dataset/train_new/image')
dir_mask = Path('/teamspace/studios/this_studio/Ragnor/dataset/train_new/label')
dir_checkpoint = Path('./checkpoints_0430_3/')

dir_val_img = Path('/teamspace/studios/this_studio/Ragnor/dataset/val_new/image')
dir_val_mask = Path('/teamspace/studios/this_studio/Ragnor/dataset/val_new/label')

def train_model(
        model,
        device,
        epochs: int = 20,
        batch_size: int = 64,
        learning_rate: float = 1e-5,
        val_percent: float = 0.,
        save_checkpoint: bool = True,
        img_scale: float = 1,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.999,
        gradient_clipping: float = 1.0,
):
    # 1. Create dataset

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    raw_val_dataset = BasicDataset(dir_val_img, dir_val_mask, img_scale)
    val_set, _ = random_split(raw_val_dataset, [0.1, 0.9], generator=torch.Generator().manual_seed(0))
    # 2. Split into train / validation partitions
    n_train = len(dataset)
    n_val = len(val_set)
    # train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(dataset, shuffle=True, **loader_args)
    # train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(
        dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
             val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
    )

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
    scheduler = OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(train_loader), epochs=epochs)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)
                    if model.n_classes == 1:
                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
                        loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                    else:
                        loss = criterion(masks_pred, true_masks)
                        loss += dice_loss(
                            F.softmax(masks_pred, dim=1).float(),
                            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                            multiclass=True
                        )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                division_step = (n_train // (2 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in model.named_parameters():
                            tag = tag.replace('/', '.')
                            if not (torch.isinf(value) | torch.isnan(value)).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
                        val_score = evaluate(model, val_loader, device, amp)
                        print("Val Score", val_score)
                        scheduler.step(val_score)

                        logging.info('Validation Dice score: {}'.format(val_score))
                        try:
                            experiment.log({
                                'learning rate': optimizer.param_groups[0]['lr'],
                                'validation Dice': val_score,
                                'images': wandb.Image(images[0].cpu()),
                                'masks': {
                                    'true': wandb.Image(true_masks[0].float().cpu()),
                                    'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                                },
                                'step': global_step,
                                'epoch': epoch,
                                **histograms
                            })
                        except:
                            pass

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            state_dict['mask_values'] = dataset.mask_values
            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')


if __name__ == '__main__':
    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    model = UNet(n_channels=3, n_classes=49, bilinear=False)
    model = model.to(memory_format=torch.channels_last)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    model.to(device=device)
    train_model(model, device)

100%|██████████| 22000/22000 [00:01<00:00, 16125.67it/s]
100%|██████████| 21978/21978 [00:01<00:00, 18316.90it/s]
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33manony-mouse-39621772799264332[0m. Use [1m`wandb login --relogin`[0m to force relogin




Val Score tensor(0.7058, device='cuda:0')




Val Score tensor(0.8782, device='cuda:0')


Epoch 1/20: 100%|██████████| 22000/22000 [07:32<00:00, 48.64img/s, loss (batch)=0.124]
Epoch 2/20:  49%|████▉     | 10880/22000 [03:54<29:00,  6.39img/s, loss (batch)=0.0383]

Val Score tensor(0.9484, device='cuda:0')


Epoch 2/20:  99%|█████████▉| 21824/22000 [07:26<00:27,  6.40img/s, loss (batch)=0.0218]

Val Score tensor(0.9608, device='cuda:0')


Epoch 2/20: 100%|██████████| 22000/22000 [07:30<00:00, 48.80img/s, loss (batch)=0.0242]
Epoch 3/20:  49%|████▉     | 10752/22000 [03:51<28:59,  6.47img/s, loss (batch)=0.0175]

Val Score tensor(0.9617, device='cuda:0')


Epoch 3/20:  99%|█████████▊| 21696/22000 [07:22<00:46,  6.49img/s, loss (batch)=0.014] 

Val Score tensor(0.9703, device='cuda:0')


Epoch 3/20: 100%|██████████| 22000/22000 [07:28<00:00, 49.01img/s, loss (batch)=0.0141]
Epoch 4/20:  48%|████▊     | 10624/22000 [03:48<29:21,  6.46img/s, loss (batch)=0.0123]

Val Score tensor(0.9751, device='cuda:0')


Epoch 4/20:  98%|█████████▊| 21568/22000 [07:21<01:07,  6.42img/s, loss (batch)=0.00954]

Val Score tensor(0.9766, device='cuda:0')


Epoch 4/20: 100%|██████████| 22000/22000 [07:30<00:00, 48.87img/s, loss (batch)=0.0107] 
Epoch 5/20:  48%|████▊     | 10496/22000 [03:48<30:30,  6.28img/s, loss (batch)=0.0082] 

Val Score tensor(0.9789, device='cuda:0')


Epoch 5/20:  97%|█████████▋| 21440/22000 [07:19<01:27,  6.42img/s, loss (batch)=0.00971]

Val Score tensor(0.9781, device='cuda:0')


Epoch 5/20: 100%|██████████| 22000/22000 [07:29<00:00, 48.92img/s, loss (batch)=0.00937]
Epoch 6/20:  47%|████▋     | 10368/22000 [03:44<30:11,  6.42img/s, loss (batch)=0.00835]

Val Score tensor(0.9821, device='cuda:0')


Epoch 6/20:  97%|█████████▋| 21312/22000 [07:16<01:47,  6.40img/s, loss (batch)=0.00716]

Val Score tensor(0.9833, device='cuda:0')


Epoch 6/20: 100%|██████████| 22000/22000 [07:29<00:00, 48.97img/s, loss (batch)=0.00704]
Epoch 7/20:  47%|████▋     | 10240/22000 [03:42<30:17,  6.47img/s, loss (batch)=0.00829]

Val Score tensor(0.9835, device='cuda:0')


Epoch 7/20:  96%|█████████▋| 21184/22000 [07:16<02:08,  6.37img/s, loss (batch)=0.00758]

Val Score tensor(0.9838, device='cuda:0')


Epoch 7/20: 100%|██████████| 22000/22000 [07:30<00:00, 48.79img/s, loss (batch)=0.00666]
Epoch 8/20:  46%|████▌     | 10112/22000 [03:41<31:26,  6.30img/s, loss (batch)=0.0075] 

Val Score tensor(0.9842, device='cuda:0')


Epoch 8/20:  96%|█████████▌| 21056/22000 [07:15<02:26,  6.42img/s, loss (batch)=0.00612]

Val Score tensor(0.9808, device='cuda:0')


Epoch 8/20: 100%|██████████| 22000/22000 [07:31<00:00, 48.67img/s, loss (batch)=0.00953]
Epoch 9/20:  45%|████▌     | 9984/22000 [03:38<31:05,  6.44img/s, loss (batch)=0.00688]

Val Score tensor(0.9824, device='cuda:0')


Epoch 9/20:  95%|█████████▌| 20928/22000 [07:10<02:44,  6.51img/s, loss (batch)=0.00572]

Val Score tensor(0.9865, device='cuda:0')


Epoch 9/20: 100%|██████████| 22000/22000 [07:29<00:00, 48.98img/s, loss (batch)=0.00564]
Epoch 10/20:  45%|████▍     | 9856/22000 [03:36<31:41,  6.39img/s, loss (batch)=0.00566]

Val Score tensor(0.9861, device='cuda:0')


Epoch 10/20:  95%|█████████▍| 20800/22000 [07:09<03:06,  6.42img/s, loss (batch)=0.00532]

Val Score tensor(0.9873, device='cuda:0')


Epoch 10/20: 100%|██████████| 22000/22000 [07:30<00:00, 48.79img/s, loss (batch)=0.00549]
Epoch 11/20:  44%|████▍     | 9728/22000 [03:35<32:14,  6.34img/s, loss (batch)=0.00522]

Val Score tensor(0.9889, device='cuda:0')


Epoch 11/20:  94%|█████████▍| 20672/22000 [07:07<03:27,  6.40img/s, loss (batch)=0.00797]

Val Score tensor(0.9785, device='cuda:0')


Epoch 11/20: 100%|██████████| 22000/22000 [07:30<00:00, 48.83img/s, loss (batch)=0.00678]
Epoch 12/20:  44%|████▎     | 9600/22000 [03:31<32:17,  6.40img/s, loss (batch)=0.00584]

Val Score tensor(0.9852, device='cuda:0')


Epoch 12/20:  93%|█████████▎| 20544/22000 [07:03<03:45,  6.45img/s, loss (batch)=0.00686]

Val Score tensor(0.9862, device='cuda:0')


Epoch 12/20: 100%|██████████| 22000/22000 [07:29<00:00, 49.00img/s, loss (batch)=0.00756]
Epoch 13/20:  43%|████▎     | 9472/22000 [03:28<32:10,  6.49img/s, loss (batch)=0.00519]

Val Score tensor(0.9898, device='cuda:0')


Epoch 13/20:  93%|█████████▎| 20416/22000 [07:01<04:09,  6.34img/s, loss (batch)=0.00561]

Val Score tensor(0.9879, device='cuda:0')


Epoch 13/20: 100%|██████████| 22000/22000 [07:29<00:00, 48.96img/s, loss (batch)=0.00515]
Epoch 14/20:  42%|████▏     | 9344/22000 [03:25<33:01,  6.39img/s, loss (batch)=0.00507]

Val Score tensor(0.9891, device='cuda:0')


Epoch 14/20:  92%|█████████▏| 20288/22000 [06:57<04:24,  6.46img/s, loss (batch)=0.00484]

Val Score tensor(0.9904, device='cuda:0')


Epoch 14/20: 100%|██████████| 22000/22000 [07:27<00:00, 49.21img/s, loss (batch)=0.00487]
Epoch 15/20:  42%|████▏     | 9216/22000 [03:24<32:46,  6.50img/s, loss (batch)=0.0049] 

Val Score tensor(0.9912, device='cuda:0')


Epoch 15/20:  92%|█████████▏| 20160/22000 [06:58<04:48,  6.38img/s, loss (batch)=0.00508]

Val Score tensor(0.9911, device='cuda:0')


Epoch 15/20: 100%|██████████| 22000/22000 [07:29<00:00, 48.89img/s, loss (batch)=0.00506]
Epoch 16/20:  41%|████▏     | 9088/22000 [03:21<32:56,  6.53img/s, loss (batch)=0.00491]

Val Score tensor(0.9911, device='cuda:0')


Epoch 16/20:  91%|█████████ | 20032/22000 [06:52<05:04,  6.46img/s, loss (batch)=0.00489]

Val Score tensor(0.9896, device='cuda:0')


Epoch 16/20: 100%|██████████| 22000/22000 [07:26<00:00, 49.30img/s, loss (batch)=0.00576]
Epoch 17/20:  41%|████      | 8960/22000 [03:21<34:03,  6.38img/s, loss (batch)=0.00509]

Val Score tensor(0.9913, device='cuda:0')


Epoch 17/20:  90%|█████████ | 19904/22000 [06:54<05:24,  6.45img/s, loss (batch)=0.0047] 

Val Score tensor(0.9912, device='cuda:0')


Epoch 17/20: 100%|██████████| 22000/22000 [07:30<00:00, 48.79img/s, loss (batch)=0.00443]
Epoch 18/20:  40%|████      | 8832/22000 [03:19<34:03,  6.44img/s, loss (batch)=0.00713]

Val Score tensor(0.9857, device='cuda:0')


Epoch 18/20:  90%|████████▉ | 19776/22000 [06:51<05:47,  6.40img/s, loss (batch)=0.00543]

Val Score tensor(0.9873, device='cuda:0')


Epoch 18/20: 100%|██████████| 22000/22000 [07:29<00:00, 48.89img/s, loss (batch)=0.0078] 
Epoch 19/20:  40%|███▉      | 8704/22000 [03:16<34:15,  6.47img/s, loss (batch)=0.00521]

Val Score tensor(0.9886, device='cuda:0')


Epoch 19/20:  89%|████████▉ | 19648/22000 [06:48<06:01,  6.51img/s, loss (batch)=0.00499]

Val Score tensor(0.9923, device='cuda:0')


Epoch 19/20: 100%|██████████| 22000/22000 [07:27<00:00, 49.14img/s, loss (batch)=0.00401]
Epoch 20/20:  39%|███▉      | 8576/22000 [03:13<34:29,  6.49img/s, loss (batch)=0.0044] 

Val Score tensor(0.9931, device='cuda:0')


Epoch 20/20:  89%|████████▊ | 19520/22000 [06:45<06:20,  6.53img/s, loss (batch)=0.00451]

Val Score tensor(0.9930, device='cuda:0')


Epoch 20/20: 100%|██████████| 22000/22000 [07:28<00:00, 49.08img/s, loss (batch)=0.00516]


In [5]:
import argparse
import logging
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from torch.optim.lr_scheduler import OneCycleLR

import wandb
import matplotlib.pyplot as plt

dir_img = Path('/teamspace/studios/this_studio/Ragnor/dataset/train_new/image')
dir_mask = Path('/teamspace/studios/this_studio/Ragnor/dataset/train_new/label')
dir_checkpoint = Path('./checkpoints/')


if __name__ == '__main__':
    dataset = BasicDataset(dir_img, dir_mask, 0.5)
    loader_args = dict(batch_size=8, num_workers=os.cpu_count(), pin_memory=True)
    print("A ")
    train_loader = DataLoader(dataset, shuffle=True, **loader_args)
    print("B ")
    for batch in train_loader:
        print("C ")
        images, true_masks = batch['image'], batch['mask']
        print("Image: ",images.shape)
        print("Mask: ",true_masks.shape)
        img = images[1,:,:,:]
        mask = true_masks[1,:,:]
        plt.figure(figsize=(10, 5))
            
        plt.subplot(1, 2, 1)  # 1 row, 2 columns, 1st subplot
        plt.imshow(img.permute(1,2,0))
        plt.axis('off')

        # Plot the second image
        plt.subplot(1, 2, 2)  # 1 row, 2 columns, 2nd subplot
        plt.imshow(mask)
        plt.axis('off')

        plt.show()
    


100%|██████████| 22000/22000 [00:01<00:00, 18456.10it/s]


A 
B 


In [None]:
pip install wandb

Collecting wandb
  Downloading wandb-0.16.6-py3-none-any.whl.metadata (10 kB)
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl.metadata (13 kB)
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.0.1-py2.py3-none-any.whl.metadata (9.9 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Collecting appdirs>=1.4.3 (from wandb)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)
  Downloading gitdb-4.0.11-py3-none-any.whl.metadata (1.2 kB)
Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)
  Downloading smmap-5.0.1-py3-none-any.whl.metadata (4.3 kB)
Downloadi