In [1]:
!nvidia-smi

Sat Apr  1 12:59:02 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 527.27       Driver Version: 527.27       CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro RTX 3000    WDDM  | 00000000:01:00.0 Off |                  N/A |
| N/A   68C    P8     7W /  N/A |      0MiB /  6144MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [13]:
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision import models
import torch
import torch.nn as nn
import torch.nn.functional as F

def create_model(in_ch):
    model = models.segmentation.deeplabv3_resnet101(pretrained=False,
                                                    progress=False, weights=None, weights_backbone=None)
    model.load_state_dict(torch.load('/kaggle/input/deeplabv3-pretrained/deeplabv3_pretrained.pt'))
    model.backbone.conv1 = nn.Conv2d(in_ch, 64, 7, 2, 3, bias=False)

    model.classifier = DeepLabHead(2048, 1)
    # Set the model in training mode
    model.train()
    return model

In [14]:
def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()    

    intersection = (pred * target).sum(dim=2).sum(dim=2)
    
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    
    return loss.mean()

def calc_loss(pred, target, bce_weight = 1):
    bce = F.binary_cross_entropy_with_logits(pred, target)

    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    return loss

In [29]:
import numpy as np
import torch.utils.data as data
import os
import PIL.Image as Image
from tqdm import tqdm
import glob
import torch.nn as nn
from torch import optim
import torch
from torch.utils.tensorboard import SummaryWriter

# ================== random patch dataset ==============================
class RandomOpt():
    def __init__(self):
        self.SHARED_HEIGHT = 4096  # Height to resize all papyrii
        self.BUFFER = 64  # Half-size of papyrus patches we'll use as model inputs
        self.Z_DIM = 16  # Number of slices in the z direction. Max value is 64 - Z_START
        self.Z_START = 25  # Offset of slices in the z direction
        self.DATA_DIR = "/kaggle/input/vesuvius-challenge-ink-detection"

def resize(img, SHARED_HEIGHT=RandomOpt().SHARED_HEIGHT):
    current_width, current_height = img.size
    aspect_ratio = current_width / current_height
    new_width = int(SHARED_HEIGHT * aspect_ratio)
    new_size = (new_width, SHARED_HEIGHT)
    img = img.resize(new_size)
    return img


def load_mask(split, index, DATA_DIR=RandomOpt().DATA_DIR):
    img = Image.open(f"{DATA_DIR}/mask.png").convert('1')
    img = resize(img)
    return torch.from_numpy(np.array(img))

def load_labels(split, index, DATA_DIR=RandomOpt().DATA_DIR):
    img = Image.open(f"{DATA_DIR}/inklabels.png")
    img = resize(img)
    return torch.from_numpy(np.array(img)).gt(0).float()

def load_volume(split, index, DATA_DIR=RandomOpt().DATA_DIR, Z_START=RandomOpt().Z_START, Z_DIM=RandomOpt().Z_DIM):
    # Load the 3d x-ray scan, one slice at a time
    z_slices_fnames = sorted(glob.glob(f"{DATA_DIR}/{split}/{index}/surface_volume/*.tif"))[Z_START:Z_START + Z_DIM]
    print(f"Number of files found: {len(z_slices_fnames)}")  # Add this line to print the number of files found
    z_slices = []
    for z, filename in  tqdm(enumerate(z_slices_fnames)):
        img = Image.open(filename)
        img = resize(img)
        z_slice = np.array(img, dtype="float32")
        z_slices.append(torch.from_numpy(z_slice))
    return torch.stack(z_slices, dim=0)

# Random choice of patches for training
def sample_random_location(shape, BUFFER=RandomOpt().BUFFER):
    a=BUFFER
    random_train_x = (shape[0] - BUFFER - 1 - a)*torch.rand(1)+a
    random_train_y = (shape[1] - BUFFER - 1 - a)*torch.rand(1)+a
    random_train_location = torch.stack([random_train_x, random_train_y])
    return random_train_location

def is_in_masked_zone(location, mask):
    return mask[location[0].long(), location[1].long()]

def is_in_val_zone(location, val_location, val_zone_size, BUFFER=RandomOpt().BUFFER):
    x = location[0]
    y = location[1]
    x_match = val_location[0] - BUFFER <= x <= val_location[0] + val_zone_size[0] + BUFFER
    y_match = val_location[1] - BUFFER <= y <= val_location[1] + val_zone_size[1] + BUFFER
    return x_match and y_match

class RandomPatchLocDataset(data.Dataset):
    def __init__(self, mask, val_location, val_zone_size):
        self.mask = mask
        self.val_location = val_location
        self.val_zone_size = val_zone_size
        self.sample_random_location_train = lambda x: sample_random_location(mask.shape)
        self.is_in_mask_train = lambda x: is_in_masked_zone(x, mask)

    def is_proper_train_location(self, location):
        return not is_in_val_zone(location, self.val_location, self.val_zone_size) and self.is_in_mask_train(location)

    def __len__(self):
        return 1280

    def __getitem__(self, index):
        # Generate a random patch
        # Ignore the index
        loc = self.sample_random_location_train(0)
        while not self.is_proper_train_location(loc):
            loc = self.sample_random_location_train(0)
        return loc.int().squeeze(1)

In [30]:
# ============= Model ==============
class ModelOpt:
    def __init__(self):
        # self.GPU_ID = '0'  
        self.Z_DIM = RandomOpt().Z_DIM
        self.BUFFER = RandomOpt().BUFFER
        self.SEED = 0
        self.BATCH_SIZE = 64
        self.LEARNING_RATE =1e-4
        self.TRAINING_EPOCH = 25
        self.LOG_DIR = '/kaggle/working'
        self.LOAD_VOLUME = [1, 2, 3]
        # Val
        self.VAL_LOC = (1300, 1000)
        self.VAL_SIZE = (300, 7000)

class RandomPatchModel():
    def __init__(self, opt = ModelOpt()):
        self.opt = opt
        self._setup_all()
        self.volume_list = [load_volume('train', i) for i in opt.LOAD_VOLUME]
        # Here volume: [Z_DIM, SHARED_HEIGHT, W_V1 + W_V2 + ...]
        self.volume = torch.cat(self.volume_list, dim=2)
        # Same for mask and label
        self.mask_list = [load_mask('train', i) for i in opt.LOAD_VOLUME]
        self.labels_list = [load_labels('train', i) for i in opt.LOAD_VOLUME]
        # [SHARED_HEIGHT, W_V1 + W_V2 + ...]
        self.labels = torch.cat(self.labels_list, dim=1)
        self.mask = torch.cat(self.mask_list, dim=1)

        self.net = create_model(opt.Z_DIM).to(self.device)

        # Dataset
        self.loc_datast = RandomPatchLocDataset(self.mask, val_location=opt.VAL_LOC, val_zone_size=opt.VAL_SIZE)
        self.loc_loader = data.DataLoader(self.loc_datast, batch_size=opt.BATCH_SIZE)
        # Val
        self.val_loc = []
        for x in range(opt.VAL_LOC[0], opt.VAL_LOC[0] + opt.VAL_SIZE[0], opt.BUFFER):
            for y in range(opt.VAL_LOC[1], opt.VAL_LOC[1] + opt.VAL_SIZE[1], opt.BUFFER):
                if is_in_masked_zone([torch.tensor(x),torch.tensor(y)], self.mask):
                    self.val_loc.append([[x, y]])
        print(f"======> Num Patches Val: {len(self.val_loc)}")


    def _setup_all(self):
        # random seed
        np.random.seed(self.opt.SEED)
        torch.manual_seed(self.opt.SEED)
        torch.cuda.manual_seed_all(self.opt.SEED)
        # torch
        # os.environ['CUDA_VISIBLE_DEVICES'] = self.opt.GPU_ID
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Log
        self.log_dir = self.opt.LOG_DIR
        self.ckpt = os.path.join(self.log_dir)

    def get_subvolume(self, batch_loc, volume, labels):
        # batch_loc : [batch_size, 2]
        subvolume = []
        label = []
        for l in batch_loc:
            x = l[0]
            y = l[1]
            sv = volume[:, x - self.opt.BUFFER:x + self.opt.BUFFER, y - self.opt.BUFFER:y + self.opt.BUFFER]
            sv = sv / 65535.
            subvolume.append(sv)
            if labels is not None:
                lb = labels[x - self.opt.BUFFER:x + self.opt.BUFFER, y - self.opt.BUFFER:y + self.opt.BUFFER]
                lb = lb.unsqueeze(0)
                label.append(lb)
        # [batch, Z_DIM, BUFFER, BUFFER]
        subvolume = torch.stack(subvolume)
        # [batch, 1, BUFFER, BUFFER]
        if labels is not None:
            label = torch.stack(label)
        return subvolume, label

    def augment_train_data(self, subvolume, label):
        # Add Data augmentation here
        return subvolume, label

    def train_loop(self):
        print("=====> Begin training")
#         self.criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
        self.criterion = calc_loss
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.opt.LEARNING_RATE)
        self.net.train()

        best_val_loss = 100
        best_val_acc = 0
        meter = AverageMeter()
        for epoch in range(self.opt.TRAINING_EPOCH):
            bar = tqdm(enumerate(self.loc_loader), total=len(self.loc_datast) / self.opt.BATCH_SIZE)
            bar.set_description_str(f"Epoch: {epoch}")
            for i, loc in bar:
                subvolume, label = self.get_subvolume(loc, self.volume, self.labels)
                loss = self._train_step(subvolume, label)
                meter.update(loss)
                bar.set_postfix_str(f"Avg loss: {np.round(meter.get_value(),3)}")

            val_loss, val_acc = self.validataion_loop()
            print(f"======> Val Loss:{np.round(val_loss,3)} | Val Acc:{np.round(val_acc,3)} ")
            if val_loss < best_val_loss and val_acc > best_val_acc:
                torch.save(self.net.state_dict(), os.path.join(self.ckpt, "best.pt"))
                print("======> Save best val model")

                best_val_loss = val_loss
                best_val_acc = val_acc



    def _train_step(self, subvolume, label):
        self.optimizer.zero_grad()
        # inputs: subvolume: [batch, Z_DIM, BUFFER, BUFFER]
        #         label: [batch, 1, BUFFER, BUFFER]
        outputs = self.net(subvolume.to(self.device))['out']
        loss = self.criterion(outputs, label.to(self.device))
        loss.backward()
        self.optimizer.step()
        return loss

    def validataion_loop(self):
        meter_loss = AverageMeter()
        meter_acc = AverageMeter()
        self.net.eval()
        for loc in self.val_loc:
            subvolume, label = self.get_subvolume(loc, self.volume, self.labels)
            outputs = self.net(subvolume.to(self.device))['out']
            loss = self.criterion(outputs, label.to(self.device))
            meter_loss.update(loss)
            pred = torch.sigmoid(outputs) > 0.5
            meter_acc.update(
                (pred == label.to(self.device)).sum(),
                int(torch.prod(torch.tensor(label.shape)))
            )
        self.net.train()
        return meter_loss.get_value(), meter_acc.get_value()

    def load_best_ckpt(self):
        self.net.load_state_dict(torch.load(os.path.join(self.ckpt, "best.pt")))


# For the metric
class AverageMeter(object):
    def __init__(self):
        self.sum = 0
        self.n = 0

    def update(self, x, n=1):
        self.sum += float(x)
        self.n += n

    def reset(self):
        self.sum = 0
        self.n = 0

    def get_value(self):
        if self.n:
            return self.sum / self.n
        return 0

In [31]:
# Define model
model = RandomPatchModel()

Number of files found: 0


0it [00:00, ?it/s]


RuntimeError: stack expects a non-empty TensorList