In [None]:
import os, gc, sys, time, random, copy
from IPython.display import display

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

from torchvision import transforms, ops
from torch.utils.data import Dataset, DataLoader

try:
    import segmentation_models_pytorch as smp
except:
    !pip install segmentation-models-pytorch -q
    import segmentation_models_pytorch as smp


from typing import Optional
from tqdm.notebook import tqdm
from PIL import Image
import cv2
from matplotlib import pyplot as plt
import albumentations as A

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

def clear_cache():
    """Free as much accelerator memory as possible between heavy steps.

    Python garbage is collected first so that tensors which only became unreachable now are
    actually released; only then is the CUDA caching allocator asked to return its unused
    blocks. (Emptying the cache first would leave the memory of those tensors cached.)
    `torch.cuda.empty_cache` is a no-op when CUDA has not been initialised.
    """
    gc.collect()
    torch.cuda.empty_cache()

def seed_everything(seed_value):
    """Seed every RNG this notebook draws from and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global RNG, torch's default generator and the CUDA
    generator, then disables cuDNN autotuning in favour of deterministic kernels.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)

    torch.manual_seed(seed_value)
    torch.cuda.manual_seed(seed_value)

    # Reproducibility over speed: no per-shape kernel benchmarking.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

clear_cache()

# Fix the python / numpy / torch seeds before anything stochastic runs.
seed_everything(3)

n_cpus = os.cpu_count()
print(n_cpus)

In [None]:
def compute_bbox(mask):
    """Tight bounding box of the non-zero pixels of a 2-D mask.

    Returns:
        `(rmin, rmax, cmin, cmax)`: inclusive row range first, then inclusive column range.
        This is NOT the XYXY order SAM expects; convert before prompting.
        `None` when the mask has no non-zero element.
    """
    row_hits = np.flatnonzero(np.any(mask, axis=1))
    col_hits = np.flatnonzero(np.any(mask, axis=0))
    if row_hits.size == 0 or col_hits.size == 0:
        return None  # nothing detected
    return row_hits[0], row_hits[-1], col_hits[0], col_hits[-1]

def select_points(mask, num_points):
    """Return the `num_points` highest-valued pixels of a 2-D array as `[x, y]` pairs.

    x is the column index and y the row index (the point convention SAM prompts use).
    Ties are broken by NumPy's default argsort. Points are ordered from the highest value down.
    """
    ascending = np.argsort(mask.flatten())
    strongest = ascending[::-1][:num_points]
    rows, cols = np.unravel_index(strongest, mask.shape)
    return [[col, row] for row, col in zip(rows, cols)]

def dice_coefficient(preds, targets, threshould = 0.5, smooth = 1.0):
    """Smoothed Dice score between thresholded predictions and targets.

    Args:
        preds: prediction tensor; binarised as `preds > threshould`.
        targets: target tensor of the same size (used as-is).
        threshould: binarisation threshold (parameter name kept for callers).
        smooth: additive smoothing in numerator and denominator. With the default, two
            empty masks score 1.

    Returns:
        0-d tensor `(2*|P∩T| + smooth) / (|P| + |T| + smooth)`.
    """
    assert preds.size() == targets.size()
    pred_bin = torch.where(preds > threshould, 1, 0).reshape(-1)
    target_flat = targets.reshape(-1)
    overlap = (pred_bin * target_flat).sum()
    denominator = pred_bin.sum() + target_flat.sum() + smooth
    return (2.0 * overlap + smooth) / denominator

def iou_metric(preds, targets, threshould = 0.5, smooth = 1e-6):
    """Smoothed intersection-over-union between thresholded predictions and targets.

    Args:
        preds: prediction tensor; binarised as `preds > threshould`.
        targets: target tensor of the same size (used as-is).
        threshould: binarisation threshold (parameter name kept for callers).
        smooth: added to numerator AND denominator. A correct empty prediction on an empty
            target scores 1 rather than 0, consistent with `dice_coefficient`. For non-empty
            masks the effect of the tiny default is negligible.

    Returns:
        0-d tensor `(|P∩T| + smooth) / (|P∪T| + smooth)`.
    """
    assert preds.size() == targets.size()
    preds = torch.where(preds > threshould, 1, 0)
    intersection = (preds * targets).sum()
    union = preds.sum() + targets.sum() - intersection
    return (intersection + smooth) / (union + smooth)

In [None]:
# Prediction visualization
def visualize_prediction(model):
    """Plot `model`'s prediction for one random validation and one random training sample.

    Layout: 7 panels, all axes hidden.
      * Panels 0-2 show the first channel of the validation image.
      * Panel 0 additionally overlays the prediction, panel 1 the ground-truth mask.
      * Panel 3 is an empty spacer.
      * Panels 4-6 repeat the same layout for a training sample.

    Reads the notebook globals `valid_ds`, `train_ds`, `device` and `WEIGHT_DTYPE`.
    """
    fig, axes = plt.subplots(1, 7, figsize=(7, 1))
    for ax in axes:
        ax.axis('off')

    # Validation first, then train (same random draw order as the two datasets are sampled in).
    for dataset, offset in ((valid_ds, 0), (train_ds, 4)):
        sample = dataset[np.random.randint(len(dataset))]
        for idx in range(3):
            axes[offset + idx].imshow(sample[0][0], cmap='gray')
        # Inference only: don't build an autograd graph even if called outside no_grad.
        with torch.no_grad():
            batch = torch.Tensor(sample[0]).unsqueeze(0).to(device, dtype=WEIGHT_DTYPE)
            pred = model(batch).squeeze().cpu().numpy()
        axes[offset].imshow(pred, cmap='copper', alpha=0.4)
        axes[offset + 1].imshow(sample[1], cmap='copper', alpha=0.4)

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # called once per validation; don't accumulate open figures

def make_directory(path):
    """Create the single directory `path`; its parent must already exist.

    The parent is resolved with `os.path` rather than by splitting on '/'. Bare names
    (`'out'`, parent = cwd) and trailing separators (`'out/'`) therefore work.

    Returns:
        The string 'path_exsits' when `path` already exists (value kept for existing
        callers), otherwise None after creating it.

    Raises:
        AssertionError: if the parent directory does not exist.
    """
    parent = os.path.dirname(os.path.abspath(path))
    assert os.path.exists(parent), 'Parant path Doest Exists'
    if os.path.exists(path):
        return 'path_exsits'
    os.mkdir(path)



Data

In [None]:
WEIGHT_DTYPE = torch.float32
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:

### SAM
# Install the segment-anything package into this kernel's interpreter.
# NOTE(review): installs unpinned git HEAD; consider pinning a commit for reproducibility.
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git' -q
# Download the ViT-H checkpoint into the current working directory only if it is not already there.
if 'sam_vit_h_4b8939.pth' not in os.listdir():
    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

from segment_anything import sam_model_registry, SamPredictor


# Build the ViT-H SAM model from the local checkpoint and move it to `device`.
sam = sam_model_registry['vit_h'](checkpoint='sam_vit_h_4b8939.pth')
sam.to(device=device)

# Shared predictor. The feature-extraction cell calls `set_image` on it; the dataset and trainer
# later overwrite `predictor.features` directly to reuse precomputed embeddings.
predictor = SamPredictor(sam)

In [None]:
# Image files that have a matching mask. Sorted because directory listing order is
# filesystem-dependent; without it the seeded splits below differ between machines.
datapoints = sorted(next(os.walk('../data/Kvasir-SEG/images'))[2])
datapoints = [f for f in datapoints
              if os.path.exists(os.path.join('../data/Kvasir-SEG/images', f))
              and os.path.exists(os.path.join('../data/Kvasir-SEG/masks', f))]

# 80 / 10 / 10 -> pool / valid / test, then the pool is halved into train / explore.
# NOTE: `datapoints` is rebound to the 80% pool; the SAM feature cell stores its embeddings
# in this list's order.
datapoints, test = train_test_split(datapoints, test_size=0.2, random_state=42)
valid, test = train_test_split(test, test_size=0.5, random_state=42)
train, explore = train_test_split(datapoints, test_size=0.5, random_state=42)

# Split table (one row per file), built in one shot instead of row-by-row `.loc` appends.
df = pd.DataFrame(
    [(name, split)
     for names, split in zip([train, valid, explore, test], ['train', 'valid', 'explore', 'test'])
     for name in names],
    columns=['file_name', 'split'],
)

len(train), len(valid), len(explore), len(test)

In [None]:
##############################################
### Extract a SAM image embedding for every file in `datapoints` (the train + explore pool),
### stored in `datapoints` order. Expensive, and recomputed every time this cell runs.
##############################################
augumnetation = A.Compose([A.SmallestMaxSize(224), A.CenterCrop(224, 224)])


def get_features(path):
    """Return SAM's embedding for one image file as a numpy array.

    The image is opened as RGB and resized/center-cropped to 224x224 with `augumnetation`.
    It is then passed to `predictor.set_image`, and the resulting `predictor.features` tensor
    is copied to host memory.
    """
    image_path = os.path.join('../data/Kvasir-SEG/images', path)
    rgb = np.array(Image.open(image_path).convert("RGB"))
    cropped = augumnetation(image=rgb)['image']
    predictor.set_image(cropped)
    return predictor.features.cpu().numpy()


features = np.array([get_features(name) for name in tqdm(datapoints)])

In [None]:
from multiprocessing import Pool
def _read_as_array(full_path):
    """Open an image file with PIL and expose it as a numpy array (no dtype conversion)."""
    return np.asarray(Image.open(full_path))


def load_image(image_file):
    """Load the Kvasir-SEG image called `image_file`."""
    return _read_as_array(os.path.join('../data/Kvasir-SEG/images', image_file))


def load_mask(image_file):
    """Load the Kvasir-SEG mask that shares its file name with image `image_file`."""
    return _read_as_array(os.path.join('../data/Kvasir-SEG/masks', image_file))

class ImageDataset(Dataset):
    """Kvasir-SEG segmentation dataset driven by the notebook's split table.

    Modes:
      * 'train' / 'valid' / 'test': rows of that split.
      * 'explore': the 'train' + 'explore' rows. Explore rows carry their precomputed SAM
        embedding, so the trainer can build SAM pseudo-labels on the fly.
      * 'batched_explore': the same rows, but the explore rows' pseudo-labels are generated
        once, here, by prompting SAM with `model`'s prediction.

    All images and masks are loaded into memory up front with a process pool.

    Args:
        image_files: the split table (DataFrame with 'file_name' and 'split' columns). Any
            non-DataFrame value falls back to the notebook-global `df`.
        mode: one of the modes above.
        model: segmentation model used to prompt SAM; required for 'batched_explore'.
        size: side length of the square crops produced by the transforms.
        center_crop: use a center crop instead of a random crop in the train-time pipeline.

    Relies on notebook globals: `df`, `datapoints`, `features`, `predictor`, `device`,
    `WEIGHT_DTYPE`, `load_image`, `load_mask`, `compute_bbox`, `select_points`.
    """

    def __init__(self, image_files, mode = 'train', model = None, size = 224, center_crop = False):
        self.mode = mode
        split_df = image_files if isinstance(image_files, pd.DataFrame) else df

        if mode in ('train', 'valid', 'test'):
            self.image_files = split_df[split_df.split == mode].file_name.values
        elif mode in ('explore', 'batched_explore'):
            self._init_explore_rows(split_df, mode)
        else:
            raise ValueError(f'unknown mode: {mode!r}')

        # Load every image and mask in parallel; always release the worker processes.
        pool = Pool(processes=os.cpu_count())
        try:
            self.images = pool.map(load_image, tqdm(self.image_files, total=len(self.image_files)))
            self.masks = pool.map(load_mask, tqdm(self.image_files, total=len(self.image_files)))
        finally:
            pool.close()
            pool.join()

        self.image_transforms = A.Compose([
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.Rotate(p=0.5),
            A.Transpose(p=0.5),
            A.SmallestMaxSize(size),
            A.CenterCrop(size,size) if center_crop else A.RandomCrop(width=size, height=size),
        ])

        self.image_transforms_test = A.Compose([
            A.SmallestMaxSize(size),
            A.CenterCrop(size,size),
        ])

        self.Normalize  = A.Compose([
            A.Normalize(0.5, 0.5),
        ])

        if self.mode == 'batched_explore':
            self._precompute_pseudo_labels(model)

    def _init_explore_rows(self, split_df, mode):
        """Select the train + explore rows and attach each explore row's SAM embedding."""
        pool_rows = split_df[~split_df.split.isin(['test', 'valid'])]
        self.image_files = pool_rows.file_name.values
        # Own copy: it is relabelled in place below and must not write through to the table.
        self.split = pool_rows.split.values.copy()
        if mode == 'batched_explore':
            self.split[self.split == 'explore'] = 'batched_explore'

        # `features` was computed in `datapoints` order, which is not the row order of the
        # split table, so each embedding is looked up by file name, not by position.
        feature_row = {name: row for row, name in enumerate(datapoints)}
        self.features = []
        for img_file, split in tqdm(zip(self.image_files, self.split)):
            if split == 'train':
                self.features.append(None)
            elif split in ('explore', 'batched_explore'):
                self.features.append(torch.Tensor(features[feature_row[img_file]]))

    def _precompute_pseudo_labels(self, model):
        """Resize every sample to the test-time size; replace explore masks with SAM pseudo-labels."""
        assert model is not None, 'need model input for batched explore'
        model.eval()
        model.to(device)
        for idx, split in tqdm(enumerate(self.split)):
            resized_image = self.image_transforms_test(image=self.images[idx])['image']
            if split == 'batched_explore':
                self.masks[idx] = self._sam_pseudo_mask(model, resized_image, self.features[idx])
            else:
                self.masks[idx] = self.image_transforms_test(image=self.masks[idx])['image']
            self.images[idx] = resized_image

    @torch.no_grad()
    def _sam_pseudo_mask(self, model, image, feature):
        """Prompt SAM with `model`'s prediction for `image`; return an (H, W, 3) uint8 0/255 mask."""
        image_tensor = torch.Tensor(np.transpose(image, (2, 0, 1)))
        mask_pred = model(image_tensor.to(device, dtype=WEIGHT_DTYPE).unsqueeze(0)).cpu().numpy().squeeze()
        mask_pred = np.where(mask_pred > 0.5, 1, 0)
        bbox, points = compute_bbox(mask_pred), select_points(mask_pred, 5)
        if bbox is None:
            # Nothing predicted: prompt with a box covering the whole image.
            box = [0, 0, mask_pred.shape[1], mask_pred.shape[0]]
        else:
            # SAM boxes are XYXY (x = column, y = row); compute_bbox returns rows first.
            rmin, rmax, cmin, cmax = bbox
            box = [cmin, rmin, cmax, rmax]

        # Reuse the precomputed embedding instead of calling set_image again.
        predictor.is_image_set, predictor.features = True, feature.to(device)
        sam_mask, _, _ = predictor.predict(
            point_coords=np.array(points), point_labels=np.array([1 for _ in points]),
            box=np.array(box)[np.newaxis, :],
            multimask_output=False,
        )
        # Binary SAM mask -> 0/255 uint8, the same range as the mask files. `__getitem__`
        # divides by 255, so this yields a {0, 1} target instead of {0, 1/255}.
        sam_mask = sam_mask.squeeze().astype(np.uint8) * 255
        return np.repeat(sam_mask[:, :, np.newaxis], 3, axis=2)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return `(image CHW, mask HW in [0, 1], sample mode, SAM embedding or zero placeholder)`."""
        instance_image = np.array(self.images[idx])
        instance_masks = np.array(self.masks[idx])

        if self.mode in ('train', 'batched_explore'):
            augmented = self.image_transforms(image=instance_image, mask=instance_masks)
            instance_image, instance_masks = augmented['image'], augmented['mask']
        elif self.mode in ('explore', 'test', 'valid'):
            augmented = self.image_transforms_test(image=instance_image, mask=instance_masks)
            instance_image, instance_masks = augmented['image'], augmented['mask']

        # Normalize and HWC -> CHW. Keep the first mask channel and scale 0..255 to 0..1.
        instance_image = np.transpose(self.Normalize(image=instance_image)['image'], (2, 0, 1))
        instance_masks = np.transpose(instance_masks, (2, 0, 1))[0, :, :] / 255

        # Samples without an embedding get a zero tensor so every item collates to the same shapes.
        if self.mode in ('test', 'valid'):
            return instance_image, instance_masks, self.mode, torch.zeros([1, 256, 64, 64])
        if self.mode == 'train' or self.split[idx] == 'train':
            return instance_image, instance_masks, 'train', torch.zeros([1, 256, 64, 64])
        if self.split[idx] == 'batched_explore':
            return instance_image, instance_masks, 'batched_explore', torch.zeros([1, 256, 64, 64])
        return instance_image, instance_masks, 'explore', self.features[idx].cpu()



In [None]:
# Eagerly load the three labelled splits (constructed in the order train, test, valid).
train_ds = ImageDataset(df, 'train')
test_ds = ImageDataset(df, 'test')
valid_ds = ImageDataset(df, 'valid')

In [None]:
BATCH_SIZE = 8
NUM_WORKERS = os.cpu_count()

# Training batches are shuffled; test/valid use half the batch size in a fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE // 2, shuffle=False, num_workers=NUM_WORKERS)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE // 2, shuffle=False, num_workers=NUM_WORKERS)


In [None]:
import segmentation_models_pytorch as smp

# Stage-1 model: UNet++ with a ResNet-34 encoder, 3-channel input and one output channel.
# The final sigmoid means the model emits per-pixel probabilities, not logits.
model = smp.UnetPlusPlus(
    encoder_name="resnet34",
    encoder_weights=None,   # random init; pass "imagenet" for a pretrained encoder
    in_channels=3,
    classes=1,
    activation = 'sigmoid',
)
clear_cache()

In [None]:

def loss_func(logits, targets, weight = None, from_logits = False):
    """Per-sample weighted `Dice + 0.2 * BCE` segmentation loss, averaged over the batch.

    Args:
        logits: model outputs, presumably (B, 1, H, W). Despite the name these are treated as
            probabilities by default, because every model in this notebook ends in
            `activation='sigmoid'`. Pass `from_logits=True` for raw scores.
        targets: binary targets with one fewer dim than `logits` (e.g. (B, H, W)).
        weight: optional per-sample weights of length B (tensor or sequence). None or empty
            means all ones.
        from_logits: whether `logits` are raw scores (the losses then apply the sigmoid).

    Returns:
        0-d tensor `sum_i w_i * (dice_i + 0.2 * bce_i) / B`.
    """
    if weight is None or len(weight) == 0:
        weight = torch.ones(len(logits))

    # Both criteria must agree with what the model emits. A second sigmoid on
    # already-sigmoided outputs squashes them into [0.5, 0.73] and flattens the gradients.
    dice_criterion = smp.losses.DiceLoss('binary', from_logits=from_logits)
    bce_criterion = nn.BCEWithLogitsLoss() if from_logits else nn.BCELoss()

    total_loss = 0
    for sample_pred, sample_target, sample_weight in zip(logits, targets, weight):
        sample_pred, sample_target = sample_pred.unsqueeze(0), sample_target.unsqueeze(0)
        dice = dice_criterion(sample_pred, sample_target)
        bce = bce_criterion(sample_pred, sample_target.unsqueeze(1))
        total_loss += (dice + 0.2 * bce) * sample_weight

    return total_loss / len(logits)

In [None]:
folder_name = 'Kvasir-SEG-UnetPlusPlus-25train-Trial_3'

# Refuse to overwrite an earlier run. The guard must check the same '../Result' root the
# directories below are created under.
assert not os.path.exists(f'../Result/{folder_name}'),  'path already existed'

make_directory('../Result')
make_directory(f'../Result/{folder_name}')
# One sub-folder (plus test_masks/) per training stage.
for stage in ('train', 'explore', 'batched_explore'):
    make_directory(f'../Result/{folder_name}/{stage}')
    make_directory(f'../Result/{folder_name}/{stage}/test_masks')

In [None]:

class Trainer:
    """Train a segmentation model, optionally on SAM pseudo-labels, with periodic validation.

    Every training batch first goes through `get_masks`. For each 'explore' sample, it
    replaces the target with a SAM mask prompted by the current model's prediction; it then
    re-augments image and target. Samples whose mode is not 'train' are weighted by
    `pseudo_label_weight` in the loss.

    Validation runs every `per_iter_valid` steps:
      * The weights with the lowest valid loss so far are saved to
        `{save_directory}/UnetPlusPlus.pth`.
      * The results table is written to `{save_directory}/UnetPlusPlus.csv`.
      * Training stops early when `current >= previous + early_stopping_delta`, where
        `previous` is the valid loss logged `early_stopping_patience` validations earlier.
      * On early stop, or after the last epoch, the best checkpoint is reloaded and evaluated
        on the test set.

    Relies on notebook globals: `device`, `train_ds` (for its transforms), `predictor`,
    `loss_func`, `compute_bbox`, `select_points`, `dice_coefficient`, `iou_metric`,
    `visualize_prediction`, `clear_cache`.
    """

    def __init__(self, Unet, optimizer, train_dl, valid_dl, test_dl, total_epoch, WEIGHT_DTYPE,
                 pseudo_label_weight = 0.25, per_iter_valid = 60, log_period = 20, lr_lowerbound = 1e-7,
                 save_directory = '', early_stopping_patience = 5,  early_stopping_delta = 1e-2):

        assert save_directory != '' and os.path.exists(save_directory), 'saving directory needs to exists'
        assert early_stopping_delta >= 0,                               'early_stopping_delta needs to be poisitive'

        # Model related
        self.Unet = Unet.to(device, dtype=WEIGHT_DTYPE)
        self.optimizer = optimizer

        # Data related
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.test_dl = test_dl

        # Schedule / bookkeeping
        self.WEIGHT_DTYPE = WEIGHT_DTYPE
        self.total_epoch = total_epoch
        self.per_iter_valid = per_iter_valid
        self.total_step = 0
        self.log_period = log_period
        self.pseudo_label_weight = pseudo_label_weight
        self.lr_lowerbound = lr_lowerbound
        self.save_directory = save_directory
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_delta = early_stopping_delta
        self.checkpoint_path = f'{save_directory}/UnetPlusPlus.pth'
        self.results_csv_path = f'{save_directory}/UnetPlusPlus.csv'

        # Result display
        self.result_df = pd.DataFrame(columns=['epoch', 'steps', 'lr', 'Train Loss', 'Valid Loss', 'valid_iou', 'valid_dice'])
        self._display_id = None
        self.display_line = ''
        self.best_result_iter = 0

        print(f'total steps: {len(train_dl) * total_epoch}')

        # Halve the LR after 3 validations without improvement, never below `lr_lowerbound`.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode="min", patience=3, factor=0.5, min_lr=self.lr_lowerbound
        )

    def _logged_valid_losses(self):
        """Valid-loss column of `result_df`, without the ' --- ' placeholders of plain log rows."""
        return [x for x in self.result_df['Valid Loss'] if x != ' --- ']

    def _load_best_checkpoint(self):
        """Reload the best saved weights (if a validation has saved any) and switch to eval mode."""
        if os.path.exists(self.checkpoint_path):
            self.Unet.load_state_dict(torch.load(self.checkpoint_path, map_location=device))
        self.Unet.eval()

    def _sam_mask(self, img, feature):
        """SAM mask (H, W, values 0/1) for one 'explore' sample, prompted by the current model."""
        temp_img = np.transpose(np.array(img), (1, 2, 0))
        temp_img = torch.Tensor(np.transpose(train_ds.image_transforms_test(image=temp_img)['image'], (2, 0, 1)))
        mask_pred = self.Unet(temp_img.to(device, dtype=self.WEIGHT_DTYPE).unsqueeze(0)).cpu().numpy().squeeze()
        # Point prompts: the 5 most confident pixels of the soft prediction.
        points = select_points(mask_pred, 5)
        bbox = compute_bbox(np.where(mask_pred > 0.5, 1, 0))
        if bbox is None:
            # Nothing predicted: prompt with a box covering the whole image.
            box = [0, 0, mask_pred.shape[1], mask_pred.shape[0]]
        else:
            # SAM boxes are XYXY (x = column, y = row); compute_bbox returns rows first.
            rmin, rmax, cmin, cmax = bbox
            box = [cmin, rmin, cmax, rmax]

        # Reuse the precomputed embedding instead of calling set_image.
        predictor.is_image_set, predictor.features = True, feature.to(device)
        sam_mask, _, _ = predictor.predict(
            point_coords=np.array(points), point_labels=np.array([1 for _ in points]),
            box=np.array(box)[np.newaxis, :],
            multimask_output=False,
        )
        return (sam_mask / 1.).squeeze()

    @torch.no_grad()
    def get_masks(self, batch):
        """Build the (images, targets) tensors actually used for one training step.

        For each sample, 'explore' samples get a SAM target from `_sam_mask`. The target is
        then resized with `train_ds.image_transforms_test`, and image + target are augmented
        together with `train_ds.image_transforms`.
        """
        self.Unet.eval()
        imgs, masks = [], []

        for img, mask, mode, feature in zip(batch[0], batch[1], batch[2], batch[3]):
            if mode == 'explore':
                mask = self._sam_mask(img, feature)

            img = np.transpose(np.array(img), (1, 2, 0))
            mask = np.array(mask)[:, :, np.newaxis]
            mask = train_ds.image_transforms_test(image=mask)['image']

            augmented = train_ds.image_transforms(image=img, mask=mask)
            img = np.transpose(augmented['image'], (2, 0, 1))
            mask = np.transpose(augmented['mask'], (2, 0, 1)).squeeze()

            imgs.append(img)
            masks.append(mask)

        self.Unet.train()
        return torch.from_numpy(np.array(imgs)), torch.from_numpy(np.array(masks))

    @torch.no_grad()
    def valid(self):
        """Evaluate on `valid_dl` and plot a sample prediction.

        Returns:
            (mean loss, mean IoU, mean Dice) per validation sample.
        """
        self.Unet.eval()
        valid_loss, valid_iou, valid_dice, number_of_instance = [], [], [], 0
        for batch in tqdm(self.valid_dl, desc = 'validating', leave = False):
            pixel_values = batch[0].to(device, dtype=self.WEIGHT_DTYPE)
            prediction = self.Unet(pixel_values)
            target = batch[1].to(device, dtype=self.WEIGHT_DTYPE)

            loss = loss_func(prediction.float(), target.float())
            # loss_func averages over the batch; re-weight by batch size for a per-sample mean.
            valid_loss.append(loss.item() * len(batch[0]))

            for p, t in zip(prediction, target):
                p, t = p.squeeze().cpu(), t.cpu()
                valid_dice.append(dice_coefficient(p, t))
                valid_iou.append(iou_metric(p, t))

            number_of_instance += len(batch[0])

        visualize_prediction(self.Unet)

        clear_cache()
        self.Unet.train()

        return sum(valid_loss) / number_of_instance, sum(valid_iou) / number_of_instance, sum(valid_dice) / number_of_instance

    @torch.no_grad()
    def test(self):
        """Evaluate on the dataset behind `test_dl`, save each predicted mask, and log a test row.

        Returns:
            (mean loss, mean IoU, mean Dice) over test samples.
        """
        self.Unet.eval()
        test_dataset = self.test_dl.dataset
        test_loss, test_iou, test_dice, number_of_instance = [], [], [], 0

        for idx, test_data in enumerate(tqdm(test_dataset, desc = 'testing', leave = False)):
            pixel_values = torch.Tensor(test_data[0]).to(device, dtype=self.WEIGHT_DTYPE).unsqueeze(0)
            prediction = self.Unet(pixel_values)
            target = torch.Tensor(test_data[1]).to(device, dtype=self.WEIGHT_DTYPE).unsqueeze(0)

            loss = loss_func(prediction.float(), target.float())
            test_loss.append(loss.item())

            prediction = prediction.squeeze().cpu()
            target = target.squeeze().cpu()
            test_dice.append(dice_coefficient(prediction, target))
            test_iou.append(iou_metric(prediction, target))

            number_of_instance += 1

            mask_image = Image.fromarray((prediction.numpy() * 255).astype('uint8'))
            mask_name = test_dataset.image_files[idx].replace('.', '_mask.')
            mask_image.save(f"{self.save_directory}/test_masks/{mask_name}")

        clear_cache()
        self.Unet.train()

        test_comb_loss = sum(test_loss) / number_of_instance
        test_iou = sum(test_iou) / number_of_instance
        test_dice = sum(test_dice) / number_of_instance

        self.result_df.loc[len(self.result_df)] = [self.total_step // len(self.train_dl), self.total_step, np.round(self.optimizer.param_groups[0]['lr'], 6),
                                                   f'test_result({self.best_result_iter})', np.round(test_comb_loss, 4),
                                                   np.round(test_iou, 4).item(), np.round(test_dice, 4).item()]

        if self._display_id is not None:
            self._display_id.update(self.result_df)

        self.result_df.to_csv(self.results_csv_path)

        return test_comb_loss, test_iou, test_dice

    def train(self):
        """Train for up to `total_epoch` epochs, then reload the best checkpoint and run `test()`.

        Returns:
            'early stopped' when early stopping triggers, otherwise None.
        """
        self._display_id = display(self.result_df, display_id=True)
        self.rank_display_id = display('', display_id=True)

        recorded_loss = []

        for epoch in range(self.total_epoch):
            self.Unet.train()
            pbar = tqdm(self.train_dl)
            for batch in pbar:
                imgs, masks = self.get_masks(batch)
                pixel_values = imgs.to(device, dtype=self.WEIGHT_DTYPE)
                prediction = self.Unet(pixel_values)
                target = masks.to(device, dtype=self.WEIGHT_DTYPE)

                # Ground-truth ('train') samples weigh 1; pseudo-labelled ones `pseudo_label_weight`.
                sample_weight = torch.Tensor(np.where(np.array(list(batch[2])) == 'train', 1, self.pseudo_label_weight))
                loss = loss_func(prediction.float(), target.float(), sample_weight)

                self.optimizer.zero_grad()

                recorded_loss.append(loss.item())
                pbar.set_description(f"[Loss: {recorded_loss[-1]:.3f}/{np.mean(recorded_loss):.3f}]")

                loss.backward()
                self.optimizer.step()
                self.total_step += 1

                if self.total_step % self.per_iter_valid == 0:
                    valid_comb_loss, valid_iou, valid_dice = self.valid()

                    # Lowest valid loss so far -> new best checkpoint.
                    if valid_comb_loss <= min(self._logged_valid_losses() + [1000.0]):
                        torch.save(self.Unet.state_dict(), self.checkpoint_path)
                        self.best_result_iter = epoch

                    self.result_df.loc[len(self.result_df)] = [epoch, self.total_step, np.round(self.optimizer.param_groups[0]['lr'], 6),
                                                               np.round(np.mean(recorded_loss), 4), np.round(valid_comb_loss, 4),
                                                               np.round(valid_iou, 4).item(), np.round(valid_dice, 4).item()]
                    self._display_id.update(self.result_df)

                    # Early stopping. Index -1 is the row just appended, so the loss from
                    # `early_stopping_patience` validations ago is at -(patience + 1).
                    logged = self._logged_valid_losses()
                    if len(logged) > self.early_stopping_patience and \
                            valid_comb_loss >= logged[-(self.early_stopping_patience + 1)] + self.early_stopping_delta:
                        self._load_best_checkpoint()
                        test_comb_loss, test_iou, test_dice = self.test()
                        print('early stopped' ,test_comb_loss, test_iou, test_dice)
                        return 'early stopped'

                    self.result_df.to_csv(self.results_csv_path)
                    recorded_loss = []

                    self.scheduler.step(valid_comb_loss)

                if self.total_step % self.log_period == 0:
                    self.result_df.loc[len(self.result_df)] = [epoch, self.total_step, np.round(self.optimizer.param_groups[0]['lr'], 6), np.round(np.mean(recorded_loss), 4), ' --- ', ' --- ',  ' --- ']
                    self._display_id.update(self.result_df)

        # Final evaluation with the best weights.
        self._load_best_checkpoint()
        test_comb_loss, test_iou, test_dice = self.test()
        print(test_comb_loss, test_iou, test_dice)






In [None]:
# Stage 1: supervised training of `model` on the labelled 'train' split only.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

trainer = Trainer(
    Unet = model,
    optimizer = optimizer,
    WEIGHT_DTYPE = WEIGHT_DTYPE,
    total_epoch = 100,
    train_dl = train_loader,
    valid_dl = valid_loader,
    test_dl = test_loader,
    per_iter_valid = len(train_loader),   # validate at the end of every epoch
    log_period = 1000000,                 # large enough that plain log rows effectively never appear
    early_stopping_patience = 10,
    early_stopping_delta = 0,
    save_directory = f'../Result/{folder_name}/train',
)

trainer.train()

In [None]:
# Labelled train rows + unlabelled explore rows (the latter carry their SAM embeddings).
explore_ds = ImageDataset(df, mode='explore')
explore_loader = DataLoader(explore_ds, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
# Stage 2a: fine-tune the stage-1 checkpoint on train + explore. Explore targets are
# produced on the fly by SAM inside Trainer.get_masks.
fine_tune_model = smp.UnetPlusPlus(
    encoder_name="resnet34",
    encoder_weights=None,   # all weights come from the stage-1 checkpoint loaded below
    in_channels=3,
    classes=1,
    activation = 'sigmoid',
)
stage1_weights = f'../Result/{folder_name}/train/UnetPlusPlus.pth'
fine_tune_model.load_state_dict(torch.load(stage1_weights))

optimizer = torch.optim.Adam(fine_tune_model.parameters(), lr=1e-5)

trainer_explore = Trainer(
    Unet = fine_tune_model,
    optimizer = optimizer,
    WEIGHT_DTYPE = WEIGHT_DTYPE,
    total_epoch = 100,
    train_dl = explore_loader,
    valid_dl = valid_loader,
    test_dl = test_loader,
    per_iter_valid = len(explore_loader),   # validate at the end of every epoch
    log_period = 1000000,
    pseudo_label_weight = 0.8,
    early_stopping_patience = 10,
    early_stopping_delta = 1e-3,
    save_directory = f'../Result/{folder_name}/explore',
)

trainer_explore.train()


In [None]:
# Stage 2b: start again from the stage-1 checkpoint. SAM pseudo-labels for the explore
# rows are generated once, when the dataset is built, from this model's predictions.
fine_tune_model = smp.UnetPlusPlus(
    encoder_name="resnet34",
    encoder_weights=None,   # all weights come from the stage-1 checkpoint loaded below
    in_channels=3,
    classes=1,
    activation = 'sigmoid',
)

stage1_weights = f'../Result/{folder_name}/train/UnetPlusPlus.pth'
fine_tune_model.load_state_dict(torch.load(stage1_weights))

batched_explore_ds = ImageDataset(df, mode='batched_explore', model=fine_tune_model)
batched_explore_loader = DataLoader(batched_explore_ds, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
# Fine-tune on train + precomputed SAM pseudo-labels.
optimizer = torch.optim.Adam(fine_tune_model.parameters(), lr=1e-5)

trainer_explore = Trainer(
    Unet = fine_tune_model,
    optimizer = optimizer,
    WEIGHT_DTYPE = WEIGHT_DTYPE,
    total_epoch = 100,
    train_dl = batched_explore_loader,
    valid_dl = valid_loader,
    test_dl = test_loader,
    per_iter_valid = len(batched_explore_loader),   # validate at the end of every epoch
    log_period = 1000000,
    pseudo_label_weight = 0.8,
    early_stopping_patience = 10,
    early_stopping_delta = 1e-3,
    save_directory = f'../Result/{folder_name}/batched_explore',
)
trainer_explore.train()
