In [1]:
import numpy as np
import torch.nn as nn
import torch
from psimage import PSImage
import torch.utils.data as data
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
from torchsummary import summary
import wandb
import logging
from pathlib import Path
from torch import optim
from datetime import datetime
from tqdm import tqdm
from PIL import Image
import gc

psimage tile cache size was set up to 300 tiles


In [2]:
import time

In [3]:
from nets_parts.SegNet_torch import SegNet

In [4]:
torch.cuda.is_available()

True

In [5]:
device = torch.device('cuda:1')

In [6]:
patch_on_epoch = 4000
patch_on_epoch_valid = 128
layer_num = 1
patch_size = 256
batch_size = 32

In [7]:
from nets_parts.path_to_psi import train_files, train_files_json, valid_files, valid_files_json, test_files, test_files_json

In [8]:
class ImageDataset(data.Dataset):
    def __init__(self, paths: list[str], layer_num: int, patch_size: int, 
                 batch_size: int, patch_on_epoch: int, transforms, device):
        self.psi_images = [PSImage(i) for i in paths]
        self.layer_num = layer_num
        self.patch_size = patch_size
        self.batch_size = batch_size
        self.patch_on_epoch = patch_on_epoch
        self.transforms = transforms
        self.device = device
        self.img_num_to_slice = min(3, len(paths))
        self.imgs_slice = []
        self.indexes = [i for i in range(len(self.psi_images))]
        self.index_chose = []
        for i in range(self.img_num_to_slice):
            self.add_one_slice()
        
    def __getitem__(self, index):
        if index % 200 == 0:
            self.remove_one_slice()
            self.add_one_slice()
        cycle_iters = 0
        while True:
            cycle_iters += 1
            img_ind = np.random.randint(self.img_num_to_slice)
            # left_top_corner
            w, h = self.imgs_slice[img_ind].shape[:2]
            lt_c = np.random.randint(0, [w - self.patch_size, h - self.patch_size])
            img = self.imgs_slice[img_ind][lt_c[0]: lt_c[0] + self.patch_size, lt_c[1] : lt_c[1] + self.patch_size]
            img = torch.tensor(img.astype(np.float32))
            if cycle_iters % 4 == 0 :
                self.remove_one_slice()
                self.add_one_slice()
            if img.mean() < 0.9 or cycle_iters > 15:
                img = self.transforms(img)
                img = transforms.ToTensor()(np.array(img))
                img = img.to(device)
                break

        return img, img

    def __len__(self):
        return self.patch_on_epoch

    def remove_one_slice(self):
        # remove element from index list, slice list
        pos_img_ind = np.random.randint(len(self.imgs_slice))
        img_ind = self.index_chose.pop(pos_img_ind)
        self.imgs_slice.pop(pos_img_ind)
        self.indexes.append(img_ind)
    
    def add_one_slice(self):
        while True:
            #slice
            pos_img_ind = np.random.randint(len(self.indexes))
            img_ind = self.indexes.pop(pos_img_ind)
            self.index_chose.append(img_ind)
            slice_size_h = self.psi_images[img_ind].layout.img_h //self.layer_num // 10 * 1
            h = self.psi_images[img_ind].layout.img_h //self.layer_num - slice_size_h
            slice_size_w = self.psi_images[img_ind].layout.img_w //self.layer_num // 10 * 1
            w = self.psi_images[img_ind].layout.img_w //self.layer_num - slice_size_w
            # left_top_corner
            lt_c = np.random.randint(0, [h, w])
            cur_slice = self.psi_images[img_ind].get_region_from_layer(
                self.layer_num, (lt_c[0], lt_c[1]), (lt_c[0] + slice_size_h, lt_c[1] + slice_size_w)
            ) / 255
            if cur_slice.mean() < 0.85:
                self.imgs_slice.append(cur_slice)
                break
            self.indexes.append(img_ind)
            self.index_chose.pop()


def train_loader_creator(train_list: list[str], patch_on_epoch: int = 1800, layer_num: int = 1, 
                         patch_size:int = 1024, batch_size: int = 16, device = 'cpu'):

    # [NO] do not use normalize here cause it's very hard to converge
    # [NO] do not use colorjitter cause it lead to performance drop in both train set and val set

    # [?] guassian blur will lead to a significantly drop in train loss while val loss remain the same
    augmentation = [
        transforms.GaussianBlur(5, sigma=(0.1, 2.0)),
        transforms.RandomHorizontalFlip()
    ]

    train_trans = transforms.Compose(augmentation)
    train_dataset = ImageDataset(paths=train_list, layer_num=layer_num, patch_size=patch_size, 
                 batch_size=batch_size, patch_on_epoch=patch_on_epoch, transforms=train_trans,
                 device=device)

    train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    shuffle=False,
                    batch_size=batch_size,
                    num_workers=0)

    return train_loader

In [9]:
train_loader = train_loader_creator(train_list=train_files, patch_on_epoch=patch_on_epoch, 
                                    layer_num=layer_num, patch_size=patch_size, 
                                    batch_size=batch_size, device="cuda:1")

In [10]:
valid_loader = train_loader_creator(train_list=valid_files, patch_on_epoch=patch_on_epoch_valid, 
                                    layer_num=layer_num, patch_size=patch_size, batch_size=batch_size, device="cuda:1")

In [11]:
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet

class AutoEncoder(nn.Module):
    def __init__(self, encoder_name='efficientnet-b2'):
        super(AutoEncoder, self).__init__()
        self.encoder = EfficientNet.from_name(encoder_name)
        self.decoder = nn.Sequential(
            nn.Conv2d(1408, 512, kernel_size=3, padding=1, padding_mode="reflect"),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, padding_mode="reflect"),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, padding_mode="reflect"),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, padding_mode="reflect"),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1, padding_mode="reflect"),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 3, kernel_size=3, padding=1, padding_mode="reflect"),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder.extract_features(x)
        x = self.decoder(x)
        return x

    def freeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = True

# Создаем модель
model = AutoEncoder().to(device)

In [12]:
save_checkpoint = True
dir_checkpoint = Path("./checkpoints_bn")
epochs = 60
batch_size = 32
learning_rate = 1e-3
clipping_value = 1 # arbitrary value of your choosing
weight_decay = 0
epoch_to_start_sched = 10

In [13]:
import os
image_path = "images_bn"
os.makedirs(image_path, exist_ok=True)

In [14]:
import gc

In [15]:
# (Initialize logging)
experiment = wandb.init(project='Effnet_Autoencoder', resume='allow', anonymous='must')
experiment.config.update(
    dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
        save_checkpoint=save_checkpoint)
)
now = datetime.now()

logging.info(f'''Starting training:
    Epochs:          {epochs}
    Batch size:      {batch_size}
    Learning rate:   {learning_rate}
    Checkpoints:     {save_checkpoint}
    Device:          {device.type}
    weight_decay:    {weight_decay}
''')

# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, foreach=False)
criterion = nn.MSELoss()
global_step = 0
valid_step = 0
scheduler = None
# 5. Begin training
for epoch in range(1, epochs + 1):
    if epoch >= epoch_to_start_sched:
        if epoch == epoch_to_start_sched:
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
        else:
            if scheduler is None:
                scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
            scheduler.step()
    model.train()
    epoch_loss = 0
    with tqdm(total=patch_on_epoch, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            images, true_masks = batch
            """
            assert images.shape[1] == model.n_channels, \
                f'Network has been defined with {model.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'
            """
            images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            true_masks = true_masks.to(device=device)

            masks_pred = model(images)
            loss = criterion(masks_pred, true_masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)
            pbar.update(images.shape[0])
            global_step += 1
            epoch_loss += loss.item()
            experiment.log({
                'train loss': loss.item(),
                'step': global_step,
                'epoch': epoch
            })
            pbar.set_postfix(**{'loss (batch)': loss.item(), 'lr': optimizer.param_groups[0]["lr"]})
            # Evaluation round
            del masks_pred, images, loss
            gc.collect()
            torch.cuda.empty_cache()
            experiment.log({
                'learning rate': optimizer.param_groups[0]['lr'],
                'step': global_step,
                'epoch': epoch
            })
    
    model.eval()
    epoch_loss = 0
    with tqdm(total=patch_on_epoch_valid, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
        display_images = True
        for batch in valid_loader:
            images, true_masks = batch
            """
            assert images.shape[1] == model.n_channels, \
                f'Network has been defined with {model.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'
            """
            images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            true_masks = true_masks.to(device=device)

            masks_pred = model(images)
            loss = criterion(masks_pred, true_masks)

            pbar.update(images.shape[0])
            valid_step += 1
            epoch_loss += loss.item()
            experiment.log({
                'valid step loss': loss.item(),
                'valid_step': valid_step,
                'epoch': epoch
            })
            del loss
            gc.collect()
            torch.cuda.empty_cache()
            if display_images:
                display_images = False
                plt.subplots_adjust(hspace=3, wspace=3)
                fig, ax = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))
                plt.subplots_adjust(hspace=0.03, wspace=0.03)
                for i, row in enumerate(ax):
                    for j, col in enumerate(row):
                        # show image
                        col.set_axis_off()
                        if j % 2 == 0:
                            col.imshow(images[i * 2 + j // 2].cpu().permute(1, 2, 0).numpy())
                        # show pred
                        if j % 2 == 1:
                            col.imshow(masks_pred[i * 2 + j // 2].cpu().detach().permute(1, 2, 0).numpy())
                plt.savefig(f'{image_path}/epoches_{epoch}.png')
                plt.close()
            del masks_pred, images
            gc.collect()
            torch.cuda.empty_cache()
        experiment.log({
            'valid epoch loss': epoch_loss,
            'valid_step': valid_step,
            "epoch": epoch
        })
        valid_step += 1

    if save_checkpoint and epoch > 7 and epoch % 2 == 0:
        Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
        state_dict = model.state_dict()
        torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
        logging.info(f'Checkpoint {epoch} saved!')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33manony-mouse-53924265597001403[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/60: 100%|██████████| 4000/4000 [03:33<00:00, 18.77img/s, loss (batch)=0.0217, lr=0.001]
Epoch 1/60: 100%|██████████| 128/128 [00:10<00:00, 12.24img/s]
Epoch 2/60: 100%|██████████| 4000/4000 [03:21<00:00, 19.89img/s, loss (batch)=0.0177, lr=0.001]
Epoch 2/60: 100%|██████████| 128/128 [00:17<00:00,  7.26img/s]
Epoch 3/60: 100%|██████████| 4000/4000 [03:08<00:00, 21.21img/s, loss (batch)=0.0178, lr=0.001]
Epoch 3/60: 100%|██████████| 128/128 [00:08<00:00, 14.42img/s]
Epoch 4/60: 100%|██████████| 4000/4000 [02:38<00:00, 25.16img/s, loss (batch)=0.0138, lr=0.001]
Epoch 4/60: 100%|██████████| 128/128 [00:05<00:00, 23.62img/s]
Epoch 5/60: 100%|██████████| 4000/4000 [03:17<00:00, 20.28img/s, loss (batch)=0.0153, lr=0.001]
Epoch 5/60: 100%|██████████| 128/128 [00:06<00:00, 19.33img/s]
Epoch 6/60: 100%|██████████| 4000/4000 [02:23<00:00, 27.89img/s, loss (batch)=0.0165, lr=0.001] 
Epoch 6/60: 100%|██████████| 128/128 [00:05<00:00, 21.54img/s]
Epoch 7/60: 100%|██████████| 4000/4000 [02:36

<Figure size 640x480 with 0 Axes>