# Imports

In [25]:
import os

from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from transformers import ViTMAEModel, AutoImageProcessor

from src.features.vitmae.dataset import init_downstream_datasets, init_dataloaders


# Device

In [26]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Image Processor

In [27]:
# Checkpoint name handed to `AutoImageProcessor.from_pretrained`.
image_processor_checkpoint = r"facebook/vit-mae-base"
# The resulting processor is later passed to `init_downstream_datasets`.
image_processor = AutoImageProcessor.from_pretrained(image_processor_checkpoint)

# Dataset

In [28]:
# Image and mask directories for the train/val splits; all four are passed to
# `init_downstream_datasets`.
# NOTE(review): every path is an empty placeholder here — fill them in before running.
train_images_dir = r""
val_images_dir = r""
train_masks_dir = r""
val_masks_dir = r""

In [29]:
# Loader settings forwarded to `init_dataloaders`.
batch_size_train = 8
batch_size_val = 8
# NOTE(review): presumably handed through to torch's DataLoader — confirm in
# `init_dataloaders`.
pin_memory = True
num_workers = 0

In [30]:
# Build the train and validation datasets from the image/mask directories.
# The image processor is passed along; presumably it is applied per sample —
# confirm in `init_downstream_datasets`.
train_dataset, val_dataset = init_downstream_datasets(
    train_images_dir=train_images_dir,
    val_images_dir=val_images_dir,
    train_masks_dir=train_masks_dir,
    val_masks_dir=val_masks_dir,
    image_processor=image_processor
)

In [31]:
# Wrap both datasets in loaders using the batch sizes and worker settings
# defined earlier; the loaders are iterated as (inputs, masks) pairs by the
# training loop.
train_dataloader, val_dataloader = init_dataloaders(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size_train=batch_size_train,
    batch_size_val=batch_size_val,
    pin_memory=pin_memory,
    num_workers=num_workers
)

# Model

In [32]:
# Load the ViT-MAE weights (the checkpoint path is left blank here) and place
# the model on the selected device in one expression.
model = ViTMAEModel.from_pretrained(r"").to(device)

Some weights of the model checkpoint at C:\Internship\ITMO_ML\CTCI\checkpoints\vit\vitmae_on_bubbles\run3\epoch_10_config\\ were not used when initializing ViTMAEModel: ['decoder.decoder_layers.6.attention.attention.query.weight', 'decoder.decoder_layers.1.intermediate.dense.weight', 'decoder.decoder_layers.7.attention.attention.value.weight', 'decoder.decoder_layers.4.intermediate.dense.bias', 'decoder.decoder_layers.1.attention.output.dense.weight', 'decoder.decoder_norm.weight', 'decoder.decoder_layers.1.attention.attention.query.weight', 'decoder.decoder_layers.6.intermediate.dense.bias', 'decoder.decoder_layers.0.attention.attention.value.weight', 'decoder.mask_token', 'decoder.decoder_layers.3.intermediate.dense.bias', 'decoder.decoder_embed.bias', 'decoder.decoder_layers.1.attention.attention.key.bias', 'decoder.decoder_layers.2.output.dense.bias', 'decoder.decoder_layers.3.output.dense.weight', 'decoder.decoder_layers.7.layernorm_before.bias', 'decoder.decoder_layers.2.layernor

In [33]:
class FocalLoss(nn.Module):
    """Focal loss built on top of cross-entropy.

    Every element's cross-entropy ``ce`` is scaled by
    ``alpha * (1 - exp(-ce)) ** gamma`` so that well-classified elements
    contribute less, and the scaled values are averaged into a scalar.

    Args:
        alpha: Constant multiplier applied to every element's loss.
        gamma: Exponent of the ``(1 - exp(-ce))`` modulating factor.
    """

    def __init__(self, alpha=0.8, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss as a 0-dim tensor.

        Args:
            inputs: Raw scores accepted by ``nn.CrossEntropyLoss``.
            targets: Targets for ``nn.CrossEntropyLoss``; they are cast to
                float first, so they are interpreted as per-class
                probabilities with the same shape as ``inputs``.
        """
        targets = targets.float()

        # reduction="none" keeps one loss value per element. With the default
        # mean reduction the modulating factor would be computed once from the
        # batch-average loss, which defeats the per-example weighting.
        criterion = nn.CrossEntropyLoss(reduction="none")

        ce_loss = criterion(inputs, targets)
        ce_exp = torch.exp(-ce_loss)
        focal_loss = (self.alpha * (1 - ce_exp) ** self.gamma * ce_loss).mean()
        return focal_loss

In [34]:
class Conv3x3(nn.Module):
    """3x3 convolution (padding 1) followed by optional extra layers.

    Layers are stacked in a fixed order: the convolution, ``Dropout2d``
    (p=0.33), ``BatchNorm2d``, ``InstanceNorm2d`` and finally the activation
    module. Each optional stage is added only when its argument is truthy.
    """

    def __init__(self, in_channels, out_channels,
                 dropout=False, batch_norm=False, instance_norm=False, activation_func=None, bias=True):
        super().__init__()
        self.net = self._init_net(
            in_channels, out_channels,
            dropout, batch_norm, instance_norm, bias, activation_func,
        )

    def _init_net(self, in_channels, out_channels, dropout, batch_norm, instance_norm, bias, activation_func):
        """Assemble the ``nn.Sequential`` that ``forward`` runs."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=bias)

        # (switch, factory) pairs in stacking order; a factory is only invoked
        # when its switch is truthy, so disabled layers are never constructed.
        optional_stages = (
            (dropout, lambda: nn.Dropout2d(p=0.33)),
            (batch_norm, lambda: nn.BatchNorm2d(out_channels)),
            (instance_norm, lambda: nn.InstanceNorm2d(out_channels)),
            (activation_func, lambda: activation_func),
        )
        extras = [make() for enabled, make in optional_stages if enabled]
        return nn.Sequential(conv, *extras)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.net(x)


In [35]:
class SegmentModel(torch.nn.Module):
    """Segmentation head on top of a frozen encoder.

    The encoder's ``last_hidden_state`` is reshaped to a ``150 x 16 x 16``
    feature map per sample, upsampled four times by a factor of two
    (16 -> 256 pixels per side) and reduced to a single sigmoid channel.

    Args:
        vitmae: Encoder module. All of its parameters get
            ``requires_grad = False``, so only the decoder and classifier
            remain trainable.
    """

    def __init__(self, vitmae):
        super().__init__()
        self.encoder = vitmae
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=150, out_channels=150, kernel_size=2, stride=2),
            Conv3x3(in_channels=150, out_channels=100, batch_norm=False, activation_func=nn.LeakyReLU()), # 32 x 32
            nn.ConvTranspose2d(in_channels=100, out_channels=100, kernel_size=2, stride=2),
            Conv3x3(in_channels=100, out_channels=64, batch_norm=False, activation_func=nn.LeakyReLU()), # 64 x 64
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2),
            Conv3x3(in_channels=64, out_channels=32, batch_norm=False, activation_func=nn.LeakyReLU()),  # 128 x 128
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2),
            Conv3x3(in_channels=32, out_channels=16, batch_norm=False, activation_func=nn.LeakyReLU())  # 256 x 256
        )

        self.classifier = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        """Predict a one-channel mask with values in [0, 1].

        Args:
            inputs: Mapping-like batch exposing a ``.data`` dict with a
                ``"pixel_values"`` tensor shaped either
                ``(batch, 1, channels, height, width)`` or
                ``(batch, channels, height, width)``. Every entry of
                ``.data`` is forwarded to the encoder as a keyword argument.

        Returns:
            Tensor of shape ``(batch, 1, 256, 256)``.
        """
        pixel_values = inputs.data["pixel_values"]
        if pixel_values.dim() == 5:
            # Drop the extra axis that sits between batch and channels.
            batch_size, _, num_channels, height, width = pixel_values.shape
            pixel_values = torch.reshape(
                pixel_values,
                (batch_size, num_channels, height, width)
            )
        batch_size = pixel_values.shape[0]

        # Work on a copy of the kwargs: writing the reshaped tensor back into
        # `inputs` would change the caller's batch, and a second forward pass
        # on that same batch would then fail.
        encoder_inputs = dict(inputs.data)
        encoder_inputs["pixel_values"] = pixel_values

        outputs = self.encoder(**encoder_inputs)
        last_hidden_state = outputs.last_hidden_state
        # NOTE(review): needs 150 * 16 * 16 = 38400 values per sample, which
        # equals 50 tokens x 768 features — presumably the ViT-MAE base output
        # with its default masking; confirm for other checkpoints.
        features = torch.reshape(last_hidden_state, (batch_size, 150, 16, 16))
        features = self.decoder(features)

        mask = self.classifier(features)

        return mask
        
        

In [36]:
# Wrap the loaded ViT-MAE model in the segmentation network. The SegmentModel
# constructor sets requires_grad=False on every parameter of `model`.
segmentator = SegmentModel(vitmae=model)

In [37]:
# Optimise the segmentation network's trainable parameters (its decoder and
# classifier). `model` is the encoder that SegmentModel freezes, so an
# optimizer built from `model.parameters()` would never update any weight.
optimizer = torch.optim.Adam(
    [param for param in segmentator.parameters() if param.requires_grad]
)

In [38]:
# Binary cross-entropy on probabilities; SegmentModel's classifier ends in a
# Sigmoid, so its outputs are already in [0, 1].
segmentation_criterion = nn.BCELoss()

In [39]:
# Directory that receives one `epoch_<n>.pt` checkpoint per epoch.
# NOTE(review): empty placeholder — with "" the files are written relative to
# the current working directory.
save_dir = r""

In [40]:
def save_model(model, path):
    """Write the model's ``state_dict`` to ``path`` with ``torch.save``."""
    weights = model.state_dict()
    torch.save(weights, path)

In [41]:
def downstream_task_train(
        model, optimizer, segmentation_criterion,
        train_dataloder, val_dataloader,
        device, save_dir,
        num_epochs=5
):
    """Train ``model`` and validate it after every epoch.

    Args:
        model: Network called as ``model(inputs)``.
        optimizer: Optimizer stepped once per training batch.
        segmentation_criterion: Callable ``(outputs, masks) -> loss``.
        train_dataloder: Iterable of ``(inputs, masks)`` training batches.
        val_dataloader: Iterable of ``(inputs, masks)`` validation batches.
        device: Device the model and every batch are moved to.
        save_dir: Directory that receives ``epoch_<n>.pt`` after each epoch.
        num_epochs: Number of passes over the training data.

    Returns:
        Dict with the keys ``"train_batch"``, ``"train_epoch"``,
        ``"val_batch"`` and ``"val_epoch"``, each holding a list of float
        losses. An epoch whose loader produced no batches is recorded as NaN.

    Note:
        The model is moved to the CPU for checkpointing at the end of every
        epoch, so it is on the CPU when this function returns.
    """

    def _mean(values):
        # An empty loader yields no losses; report NaN instead of dividing by zero.
        return sum(values) / len(values) if values else float("nan")

    history = {"train_batch": [], "train_epoch": [], "val_batch": [], "val_epoch": []}

    for epoch in range(num_epochs):
        # Checkpointing leaves the model on the CPU, so move it back each epoch.
        model = model.to(device)

        print(f"Epoch {epoch + 1}:")
        epoch_history = {"train": [], "val": []}

        model.train()
        for inputs, masks in tqdm(train_dataloder):
            optimizer.zero_grad()

            inputs = inputs.to(device)
            masks = masks.to(device)

            # Call the `model` argument rather than a module-level global, so
            # the function trains whatever network it was given.
            outputs = model(inputs)
            loss = segmentation_criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            history["train_batch"].append(batch_loss)
            epoch_history["train"].append(batch_loss)

        epoch_train_loss = _mean(epoch_history["train"])
        print(f"Epoch train loss: {epoch_train_loss}")
        history["train_epoch"].append(epoch_train_loss)

        model.eval()
        # No gradients are needed for validation; skipping graph construction
        # saves memory and time without changing the loss values.
        with torch.no_grad():
            for inputs, masks in tqdm(val_dataloader):
                inputs = inputs.to(device)
                masks = masks.to(device)

                outputs = model(inputs)
                loss = segmentation_criterion(outputs, masks)

                batch_loss = loss.item()
                history["val_batch"].append(batch_loss)
                epoch_history["val"].append(batch_loss)

        epoch_val_loss = _mean(epoch_history["val"])
        print(f"Epoch val loss: {epoch_val_loss}\n")
        history["val_epoch"].append(epoch_val_loss)

        save_model(model.to("cpu"), path=os.path.join(save_dir, f"epoch_{epoch + 1}.pt"))

    return history



In [None]:
# Train the segmentation model for 5 epochs; the returned dict of per-batch
# and per-epoch losses is plotted in the next cell.
history = downstream_task_train(
    model=segmentator,
    optimizer=optimizer,
    segmentation_criterion=segmentation_criterion,
    train_dataloder=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    save_dir=save_dir,
    num_epochs=5
)

In [None]:
# Side-by-side curves of the per-epoch training and validation losses.
fig, ax = plt.subplots(nrows=1, ncols=2)
fig.set_size_inches(18, 6)

panels = (("train_epoch", "Train loss"), ("val_epoch", "Val loss"))
for axis, (key, title) in zip(ax, panels):
    series = history[key]
    axis.plot(range(len(series)), series)
    axis.set_title(title)

plt.show()

In [None]:
# Show the predicted mask for the first sample of the first training batch.
# The batch is moved to whichever device the model currently lives on, and the
# prediction is brought back to the CPU because `.numpy()` only works there.
preview_device = next(segmentator.parameters()).device
with torch.no_grad():  # inference only, no autograd graph needed
    for inputs, masks in train_dataloader:
        inputs = inputs.to(preview_device)
        outputs = segmentator(inputs)
        mask = outputs[0].squeeze().cpu().numpy()
        plt.imshow(mask)
        break