In [None]:
import torch
from timm.models.vision_transformer import vit_base_patch16_224
from torch import nn
!pip install -q lightly[timm]
from lightly.models import utils
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from lightly.transforms import MAETransform
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from lightly.data import LightlyDataset

In [None]:
# Mount Google Drive inside the Colab VM so the dataset archive can be read from it.
from google.colab import drive

mount_point = '/content/drive'
drive.mount(mount_point)

In [None]:
!unzip -q '/content/drive/My Drive/au_opg/train.zip' -d '/content/au_opg'

In [None]:
class MAE(nn.Module):
    """Masked autoencoder: a timm ViT encoder wrapped by lightly plus a lightly MAE decoder.

    A random subset of tokens is kept and encoded; the decoder predicts the pixel
    values of the masked tokens. In addition to the prediction/target pair used for
    the loss, ``forward`` also returns full-size images for visual inspection.

    Token indexing convention (as in lightly's MAE reference example): the token
    sequence has length ``self.sequence_length``; index 0 is the class token and
    token index ``k >= 1`` corresponds to image patch ``k - 1``.

    Args:
        vit: A timm vision transformer. Its ``patch_embed.patch_size``,
            ``patch_embed.num_patches`` and ``embed_dim`` are read here.
    """

    def __init__(self, vit):
        super().__init__()

        decoder_dim = 512
        self.mask_ratio = 0.75
        self.patch_size = vit.patch_embed.patch_size[0]

        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.sequence_length = self.backbone.sequence_length
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=decoder_dim,
            decoder_depth=1,
            decoder_num_heads=16,
            mlp_ratio=4.0,
            proj_drop_rate=0.0,
            attn_drop_rate=0.0,
        )

    def forward_encoder(self, images, idx_keep=None):
        """Encode ``images``, keeping only the tokens listed in ``idx_keep``."""
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        """Predict pixel values for the masked tokens.

        Args:
            x_encoded: Encoder output for the kept tokens.
            idx_keep: Token indices that were kept (fed to the encoder).
            idx_mask: Token indices that were masked out.

        Returns:
            The decoder's pixel prediction for every token in ``idx_mask``.
        """
        # Build the decoder input: a full-length sequence of mask tokens in which
        # the kept positions are overwritten with the embedded encoder output.
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(
            self.decoder.mask_token, (batch_size, self.sequence_length)
        )
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # Decoder forward pass over the full sequence.
        x_decoded = self.decoder.decode(x_masked)

        # Only the masked positions are turned into pixel predictions.
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def forward(self, images):
        """Mask, encode, decode and reassemble a batch of images.

        Args:
            images: Batch of images; the first dimension is the batch size.

        Returns:
            Tuple ``(x_pred, target, reconstructed_image, masked)``:
            ``x_pred`` are the predicted patches for the masked tokens, ``target``
            the matching ground-truth patches, ``reconstructed_image`` the image
            with predictions pasted into the masked positions, and ``masked`` the
            image with masked patches zeroed out.
        """
        batch_size = images.shape[0]

        # Sample the mask over the FULL token sequence (class token included).
        # With the default settings random_token_mask never masks index 0, so the
        # class token is always the first entry of idx_keep.
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )

        # Encode the unmasked tokens and predict the masked ones.
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(
            x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask
        )

        # Original image patches; these have no class token, so token index k maps
        # to patch index k - 1. Drop the class token from idx_keep and shift both
        # index sets before touching `patches`.
        patches = utils.patchify(images, self.patch_size)
        patch_idx_keep = idx_keep[:, 1:] - 1
        patch_idx_mask = idx_mask - 1

        # Ground truth for the masked patches (aligned with x_pred).
        target = utils.get_at_index(patches, patch_idx_mask)

        # Image with only the visible patches; masked patches stay zero.
        # zeros_like keeps the buffer on the same device/dtype as the input instead
        # of assuming CUDA.
        visible_patches = utils.get_at_index(patches, patch_idx_keep)
        masked_patches = utils.set_at_index(
            torch.zeros_like(patches), patch_idx_keep, visible_patches
        )

        # Visible patches plus the predictions pasted into the masked positions.
        reconstructed_patches = utils.set_at_index(
            masked_patches.clone(), patch_idx_mask, x_pred.type_as(patches)
        )

        # Fold the patch sequences back into images.
        reconstructed_image = utils.unpatchify(
            reconstructed_patches, patch_size=self.patch_size
        )
        masked = utils.unpatchify(masked_patches, patch_size=self.patch_size)

        return x_pred, target, reconstructed_image, masked



In [None]:
# ViT-Base/16 (224x224 input) from timm, created with its default arguments,
# wrapped in the masked autoencoder defined above.
vit = vit_base_patch16_224()
model = MAE(vit=vit)

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# torch.autograd.detect_anomaly()
model = model.to(device)


In [None]:
# transform = MAETransform(min_scale=1, normalize={"mean": [0.5,0.5,0.5], "std":[0.5,0.5,0.5]})
# print(transform.transform)

In [None]:
def min_max_normalize(tensor):
    """Linearly rescale ``tensor`` so its minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over the whole tensor (all dimensions).

    Args:
        tensor: Non-empty tensor to rescale.

    Returns:
        ``(tensor - min) / (max - min)``. For a constant tensor (``max == min``)
        the result is all zeros rather than NaN.
    """
    min_val = tensor.min()
    max_val = tensor.max()
    value_range = max_val - min_val
    # A constant input has zero range; dividing by it would yield 0/0 = NaN.
    # Substituting 1 leaves the (all-zero) numerator unchanged.
    safe_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )
    return (tensor - min_val) / safe_range

# Alternative preprocessing: per-image min-max scaling followed by a 0.5/0.5
# normalization (maps [0, 1] to [-1, 1]). It used to be assigned to `transform`
# and was immediately overwritten by the pipeline below; it now has its own name
# so both variants stay available and the active one is unambiguous.
transform_minmax = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    min_max_normalize,
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5]),
])

# Preprocessing actually used by the dataset: resize, centre-crop to 224x224,
# convert to a tensor and normalize with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Image folder produced by the unzip step, wrapped with the preprocessing pipeline.
train_dir = "/content/au_opg/train"
dataset = LightlyDataset(train_dir, transform=transform)

In [None]:
# Shuffled batches of 256 images; the final incomplete batch is dropped.
loader_options = dict(batch_size=256, shuffle=True, drop_last=True)
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)
print(len(dataloader), 'batches')


Sanity check

In [None]:
# Push one batch through the (still untrained) model to check that shapes and
# devices line up. The model was already moved to `device` above.
first_batch = next(iter(dataloader))
# No gradients are needed for this check; skipping them avoids holding the
# autograd graph of a full batch in memory.
with torch.no_grad():
    output = model(first_batch[0].to(device))

Visualize the sanity-check output, then train

In [None]:
fig, ax = plt.subplots(1,3,figsize=(10,10))
# Inverse of the mean/std normalization applied by `transform`:
# first multiply by std, then add the mean.
# (For a 0.5/0.5 normalization the inverse would be Normalize([-1,-1,-1], [2,2,2]).)
pil_rev = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
                                                    std=[1/0.229, 1/0.224, 1/0.225]),
                               transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                                    std=[1., 1., 1.]),
                               ])
# Index of the image within the batch to inspect.
img = 6
panels = [
    ("Reconstruction", output[2][img]),
    ("Masked input", output[3][img]),
    ("Original", first_batch[0][img]),
]
for axis, (title, tensor) in zip(ax, panels):
    # imshow expects float RGB values in [0, 1]; predicted patches can fall
    # outside that range, so clamp explicitly instead of relying on matplotlib's
    # clipping (which emits a warning).
    picture = pil_rev(tensor.cpu().detach()).clamp(0, 1).permute(1, 2, 0)
    axis.imshow(picture)
    axis.set_title(title)
    axis.axis("off")


In [None]:
# AdamW over every model parameter, and a mean-squared-error loss between the
# predicted and the ground-truth pixel values of the masked patches.
learning_rate = 1.5e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

In [None]:
print("Starting Training")
# Make sure the model is in training mode even if this cell is re-run after the
# visualization cell has switched it to eval mode.
model.train()
num_batches = len(dataloader)
for epoch in range(100):
    total_loss = 0
    for batch in dataloader:
        # batch[0] holds the image tensor; the remaining entries are not used.
        images = batch[0].to(device)
        # Only the prediction/target pair is needed for the loss; the two
        # full-size images returned for visualization are ignored here.
        predictions, targets, _, _ = model(images)
        loss = criterion(predictions, targets)
        # Accumulate a detached copy so the running sum does not keep graphs alive.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / num_batches
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

Visualization

In [None]:
# Inverse of the mean/std normalization applied by `transform`
# (multiply by std, then add the mean) so tensors can be shown as images.
invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
                                                    std=[1/0.229, 1/0.224, 1/0.225]),
                               transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                                    std=[1., 1., 1.]),
                               ])
model.eval()
original_images, generated_images = [], []
num_images = 5
model = model.to(device)
train_iter = iter(dataloader)
# batch[0] is the whole image batch (as used in the training loop); indexing it
# once more would select a single image and make images[i] a colour channel.
images = next(train_iter)[0]
with torch.no_grad():
    for i in range(num_images):
        x = images[i]
        # De-normalize only for display ...
        original_images.append(
            invTrans(x).clamp(0, 1).permute(1, 2, 0).to('cpu').numpy()
        )
        # ... the model itself must see the normalized image it was trained on.
        yHat = model(x.unsqueeze(0).to(device))
        # yHat[2] is the reconstructed image (visible patches + predictions).
        reconstruction = invTrans(yHat[2].squeeze(0)).clamp(0, 1)
        generated_images.append(reconstruction.permute(1, 2, 0).to('cpu').numpy())


In [None]:
import matplotlib.pyplot as plt
# Show the three colour channels of the first original image side by side.
first_channel, second_channel, third_channel = (
    original_images[0][:, :, channel] for channel in range(3)
)
channel_images = (first_channel, second_channel, third_channel)
for position, channel_image in enumerate(channel_images, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(channel_image, cmap='gray')

In [None]:
# Two-row grid: originals on the top row, reconstructions directly below them.
plt.figure(figsize=(10, 5))
for i in range(num_images):
    for row_offset, gallery in ((0, original_images), (num_images, generated_images)):
        plt.subplot(2, num_images, i + row_offset + 1)
        plt.imshow(gallery[i])
        plt.axis('off')