In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the competition archive (cs-480-2024-spring.zip) into the working directory,
# authenticating with the token in ~/.kaggle/kaggle.json.
! kaggle competitions download -c cs-480-2024-spring

In [None]:
! unzip cs-480-2024-spring.zip -d .

In [None]:
! pip install --quiet "ipython[notebook]==7.34.0, <8.17.0" "setuptools>=68.0.0, <68.3.0" "tensorboard" "lightning>=2.0.0" "urllib3" "torch==2.3.0" "matplotlib" "pytorch-lightning>=1.4, <2.1.0" "seaborn" "torchvision" "torchmetrics>=0.7, <1.3" "matplotlib>=3.0.0, <3.9.0"

In [None]:
import os

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.notebook import tqdm

BATCH_SIZE = 256
SEED = 42

# Seed Python's `random`, NumPy and torch in one call so that data augmentation and the
# initialisation of the trainable heads are reproducible under Restart & Run All
# (up to GPU kernel non-determinism).
L.seed_everything(SEED)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

if device.type == "cuda":
    print('GPU is detected')
else:
    print('GPU is not detected.')

In [None]:
# Frozen feature extractors. ResNet152_Weights.IMAGENET1K_V1 is the checkpoint the deprecated
# `pretrained=True` flag selected, so the extracted features are unchanged.
resnet = torchvision.models.resnet152(
    weights=torchvision.models.ResNet152_Weights.IMAGENET1K_V1).to(device)
# DINOv2 ViT-g/14 (with registers) from torch.hub; downloaded on first use.
dino = torch.hub.load('facebookresearch/dinov2',
                      'dinov2_vitg14_reg').to(device)

# Swap the classifier for Identity so resnet(x) returns the pooled feature vector.
resnet_feature_size = resnet.fc.in_features
resnet.fc = nn.Identity()
# Read the embedding width from the model rather than hard-coding it (1536 for ViT-g/14).
dino_feature_size = dino.embed_dim
# NOTE(review): not referenced anywhere else in this notebook; kept only for compatibility.
inception_feature_size = 1536

# Neither backbone is trained: freeze every weight and switch to inference mode.
# The loop (rather than a bare `dino.eval()` as the last line) also keeps the cell from
# rendering the full ViT-g module repr as output.
for backbone in (resnet, dino):
    for param in backbone.parameters():
        param.requires_grad = False
    backbone.eval()

In [None]:
class PlantDataset(Dataset):
    """Plant-trait dataset that caches every image in RAM and precomputes backbone features.

    The CSV is read positionally: column 0 is the image id (``<id>.jpeg`` inside
    ``img_dir``), columns 1..163 are the auxiliary traits and columns 164 onward are the
    regression targets (a CSV without them yields an empty target vector).

    Args:
        csv_file: Path of the CSV file.
        img_dir: Directory holding the ``.jpeg`` images.
        feature_extractors: Frozen models run once over every image; their outputs are
            concatenated, in list order, into one feature vector per sample.
        transform: Transform applied to the cached PIL image on every ``__getitem__``.
        inference_transform: Transform used when running the feature extractors.
        feature_batch_size: Number of images pushed through an extractor at once.
    """

    def __init__(self, csv_file, img_dir, feature_extractors, transform, inference_transform,
                 feature_batch_size=512):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.inference_transform = inference_transform
        self.feature_batch_size = feature_batch_size
        # row index -> {'image': PIL image, 'features': numpy vector or None}
        self.image_cache = {}

        self._cache_images(feature_extractors)

    def _cache_images(self, feature_extractors):
        """Load every image into ``image_cache`` and attach the concatenated extractor features."""
        print("Caching images...")
        for idx in tqdm(range(len(self.data)), desc="Loading images"):
            img_id = self.data.iloc[idx, 0]
            img_path = os.path.join(self.img_dir, f"{img_id}.jpeg")
            # The context manager releases the file handle deterministically; convert()
            # returns a fully loaded, independent copy that stays usable afterwards.
            with Image.open(img_path) as img:
                raw_image = img.convert('RGB')
            self.image_cache[idx] = {'image': raw_image, 'features': None}

        # Nothing to extract (np.concatenate would fail on an empty list).
        if not feature_extractors or len(self.data) == 0:
            return

        # One (num_samples, dim_i) array per extractor, joined along the feature axis so each
        # row reads [extractor_0 features | extractor_1 features | ...].
        combined = np.concatenate(
            [self._extract_features(model) for model in feature_extractors], axis=1)
        for idx, features in enumerate(combined):
            self.image_cache[idx]['features'] = features

    def _extract_features(self, model):
        """Run ``model`` over all cached images in dataset order; returns a (num_samples, dim) array."""
        model_name = model.__class__.__name__
        print(f"Extracting features using {model_name}...")

        # Feed inputs to wherever the model's weights live (fall back to the notebook-wide
        # `device` for a parameter-less model).
        first_param = next(model.parameters(), None)
        model_device = first_param.device if first_param is not None else device

        num_samples = len(self.data)
        chunks = []
        # Plain index chunks: no DataLoader worker processes are needed just to batch indices.
        for start in tqdm(range(0, num_samples, self.feature_batch_size), desc=model_name):
            stop = min(start + self.feature_batch_size, num_samples)
            batch_images = torch.stack([
                self.inference_transform(self.image_cache[idx]['image'])
                for idx in range(start, stop)
            ]).to(model_device)
            with torch.no_grad():
                chunks.append(model(batch_images).cpu().numpy())

        all_features = np.concatenate(chunks, axis=0)
        print(all_features.shape)
        return all_features

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, traits, features, targets, img_id)`` for row ``idx``.

        ``image`` is the cached image passed through ``transform``; ``features`` is the
        precomputed numpy vector (None if no extractors were given).
        """
        img_id = self.data.iloc[idx, 0]
        cached = self.image_cache[idx]
        image = self.transform(cached['image'])
        image_features = cached['features']

        traits = self.data.iloc[idx, 1:164].values.astype(float)
        target_traits = self.data.iloc[idx, 164:].values.astype(float)

        return image, torch.tensor(traits, dtype=torch.float32), image_features, torch.tensor(target_traits, dtype=torch.float32), img_id


# DINOv2 ViT-g/14 splits its input into 14x14 patches, so both image sides must be a
# multiple of 14: images are brought to 128x128 and center-cropped to 126x126 (= 9 * 14).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time augmentation. RandomResizedCrop already outputs 128x128, so the former
# Resize((128, 128)) that followed it was a no-op and has been dropped.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomResizedCrop(
            (128, 128), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.CenterCrop((126, 126)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Deterministic pipeline, used for the test split and for feature extraction.
test_transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.CenterCrop((126, 126)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

train_dataset = PlantDataset(
    csv_file='data/train.csv',
    img_dir='data/train_images',
    feature_extractors=[resnet, dino],
    transform=train_transform,
    inference_transform=test_transform,
)

test_dataset = PlantDataset(
    csv_file='data/test.csv',
    img_dir='data/test_images',
    feature_extractors=[resnet, dino],
    transform=test_transform,
    inference_transform=test_transform,
)

In [None]:
def mean_r2(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    """Coefficient of determination computed per column (dim 0 = samples), then averaged.

    A machine-epsilon term in the denominator keeps constant columns from dividing by zero.
    """
    residual_ss = (y_pred - y_true).pow(2).sum(dim=0)
    centered = y_true - y_true.mean(dim=0)
    total_ss = centered.pow(2).sum(dim=0)
    eps = torch.finfo(y_true.dtype).eps
    per_column_r2 = 1 - residual_ss / (total_ss + eps)
    return per_column_r2.mean()

# 80/20 train/validation split. A dedicated, seeded generator keeps the split identical across
# runs regardless of how much of the global RNG earlier cells consumed.
SPLIT_SEED = 42
num_train = int(0.8 * len(train_dataset))
num_val = len(train_dataset) - num_train
train_set, val_set = data.random_split(
    train_dataset, [num_train, num_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

# drop_last=True avoids a size-1 final training batch, which BatchNorm rejects in training mode.
train_loader = data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
# NOTE(review): drop_last=True here discards up to BATCH_SIZE-1 validation rows; it does keep the
# per-batch R^2 from being computed on a tiny last batch. An epoch-level R^2 would avoid both.
val_loader = data.DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=4)
test_loader = data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=4)

def calculate_data_statistics(data, device):
    """Column-wise min/max/mean/std of a 2-D array as (1, num_columns) float32 tensors on ``device``.

    ``std`` is the population standard deviation (``unbiased=False``). A column with zero
    spread gets std 1 instead of 0, so standardising with ``(x - mean) / std`` maps it to 0
    rather than NaN/inf.

    Args:
        data: Array-like of shape (num_rows, num_columns).
        device: Device the returned tensors are moved to.

    Returns:
        Dict with keys 'min', 'max', 'mean' and 'std'.
    """
    tensor_data = torch.as_tensor(data, dtype=torch.float32)
    std = tensor_data.std(0, unbiased=False, keepdim=True)
    std = torch.where(std == 0, torch.ones_like(std), std)
    return {
        'min': tensor_data.min(0, keepdim=True)[0].to(device).detach(),
        'max': tensor_data.max(0, keepdim=True)[0].to(device).detach(),
        'mean': tensor_data.mean(0, keepdim=True).to(device).detach(),
        'std': std.to(device).detach(),
    }

def prepare_dataset_statistics(train_dataset, device, indices=None):
    """Normalisation statistics for the trait columns (1..163) and target columns (164 onward).

    Args:
        train_dataset: Dataset exposing the raw CSV as a DataFrame in ``.data``.
        device: Device the statistic tensors are placed on.
        indices: Optional row positions to restrict the statistics to (e.g. the training
            subset of a random split) so validation rows do not leak into the
            normalisation. ``None`` uses every row.

    Returns:
        ``{'target': stats, 'trait': stats}`` with each ``stats`` as returned by
        ``calculate_data_statistics``.
    """
    frame = train_dataset.data
    if indices is not None:
        frame = frame.iloc[list(indices)]
    training_targets = frame.iloc[:, 164:].values.astype(float)
    training_traits = frame.iloc[:, 1:164].values.astype(float)

    return {
        'target': calculate_data_statistics(training_targets, device),
        'trait': calculate_data_statistics(training_traits, device),
    }


# Fit the normalisation on the training split only (train_set is the Subset from random_split).
dataset_stats = prepare_dataset_statistics(train_dataset, device, indices=train_set.indices)

# Visual sanity check: the first few (augmented) training images in a single row.
NUM_IMAGES = 4
sample_images = [train_set[i][0] for i in range(NUM_IMAGES)]
images = torch.stack(sample_images, dim=0)

img_grid = torchvision.utils.make_grid(images, nrow=4, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)  # CHW -> HWC, the layout imshow expects

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_title("Image examples of the plants dataset")
ax.imshow(img_grid)
ax.axis("off")
plt.show()
plt.close(fig)

In [None]:
class MainPlantModel(nn.Module):
    """Three-branch MLP regressor over ResNet features, DINOv2 features and tabular traits.

    Each input goes through its own head to a 64-d embedding; the three embeddings are
    concatenated and mapped to ``num_targets`` outputs by ``combined_layers``.

    Args:
        dropout: Dropout probability inside every head and the fusion MLP.
        trait_dim: Number of trait columns fed to ``trait_head``.
        resnet_dim: Width of the ResNet slice of ``image_features``
            (defaults to the notebook-level ``resnet_feature_size``).
        dino_dim: Width of the DINOv2 slice of ``image_features``
            (defaults to the notebook-level ``dino_feature_size``).
        num_targets: Number of regression outputs.
    """

    def __init__(self, dropout=0.5, trait_dim=163, resnet_dim=None, dino_dim=None, num_targets=6):
        super().__init__()
        # Resolved here (not as default arguments) so the class can also be built with explicit
        # sizes when the notebook globals are not defined.
        self.resnet_dim = resnet_feature_size if resnet_dim is None else resnet_dim
        self.dino_dim = dino_feature_size if dino_dim is None else dino_dim

        # Attribute names and layer order match the previous hand-written layout, so state_dict
        # keys (e.g. 'trait_head.0.weight', 'combined_layers.7.bias') are unchanged.
        self.trait_head = self._make_head(trait_dim, dropout)
        self.resnet_head = self._make_head(self.resnet_dim, dropout)
        self.dino_head = self._make_head(self.dino_dim, dropout)

        self.combined_layers = nn.Sequential(
            *self._make_head(64 * 3, dropout),
            nn.Linear(64, num_targets),
        )

    @staticmethod
    def _make_head(in_features, dropout):
        """Linear(in, 256) -> BN -> LeakyReLU -> Dropout -> Linear(256, 64) -> BN -> LeakyReLU."""
        # Explicit sizes instead of Lazy* modules: parameters exist before the first forward,
        # so the optimizer and model summary see fully initialised weights.
        return nn.Sequential(
            nn.Linear(in_features, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(),
        )

    def forward(self, image_features: torch.Tensor, traits: torch.Tensor) -> torch.Tensor:
        """Predict (batch, num_targets) from concatenated backbone features and traits.

        ``image_features`` is laid out as [resnet | dino]: the first ``resnet_dim`` columns
        feed the ResNet head and the last ``dino_dim`` columns feed the DINOv2 head.
        """
        resnet_embedding = self.resnet_head(image_features[:, :self.resnet_dim])
        dino_embedding = self.dino_head(image_features[:, -self.dino_dim:])
        trait_embedding = self.trait_head(traits)
        combined = torch.cat((resnet_embedding, dino_embedding, trait_embedding), dim=1)
        return self.combined_layers(combined)


class MainPlantModule(L.LightningModule):
    """Lightning wrapper around MainPlantModel that handles trait/target (de)normalisation.

    The wrapped model sees standardised traits and produces standardised targets;
    ``forward`` returns predictions in the original target units.
    """

    def __init__(self):
        super().__init__()

        self.torch_module = MainPlantModel()

        # Normalisation statistics as non-persistent buffers: Lightning moves them with the
        # module, and they stay out of the state_dict so checkpoint keys are unchanged.
        self.register_buffer('trait_mean', dataset_stats['trait']['mean'].clone(), persistent=False)
        self.register_buffer('trait_std', dataset_stats['trait']['std'].clone(), persistent=False)
        self.register_buffer('target_mean', dataset_stats['target']['mean'].clone(), persistent=False)
        self.register_buffer('target_std', dataset_stats['target']['std'].clone(), persistent=False)

        # One real batch provides the example input for the model summary / TensorBoard graph.
        # Taking a single batch starts the loader's workers once and keeps the features and
        # traits from the same samples.
        _, traits, image_features, _, _ = next(iter(train_loader))
        self.example_input_array = (image_features.to(device), traits.to(device))

    def forward(self, image_features, traits):
        """Predict targets in original units from backbone features and raw traits."""
        normalized_traits = (traits - self.trait_mean) / self.trait_std
        pred = self.torch_module(image_features, normalized_traits)
        return pred * self.target_std + self.target_mean

    def _scale_targets(self, values):
        """Map target-unit values into the standardised space used by the loss."""
        return (values - self.target_mean) / self.target_std

    def configure_optimizers(self):
        # NOTE(review): lr=1e-1 is unusually high for AdamW; confirm it was tuned deliberately.
        optimizer = optim.AdamW(self.parameters(), lr=1e-1, weight_decay=0.01)
        # Multiply the LR by 0.2 at epochs 30, 60 and 80.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[30, 60, 80], gamma=0.2)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch", "monitor": "train_loss"}}

    def _calculate_loss(self, batch, mode='train'):
        """MSE in standardised target space; logs ``<mode>_loss`` and, outside training, ``<mode>_r2``."""
        _, traits, image_features, target, _ = batch
        scaled_outputs = self._scale_targets(self.forward(image_features, traits))
        scaled_target = self._scale_targets(target)
        loss = F.mse_loss(scaled_outputs, scaled_target)

        self.log(f'{mode}_loss', loss, prog_bar=True)

        if mode != 'train':
            # Per-column R^2 is unaffected by the per-column affine standardisation.
            self.log(f'{mode}_r2', mean_r2(
                scaled_target, scaled_outputs), prog_bar=True)

        return loss

    def training_step(self, batch, _):
        return self._calculate_loss(batch)

    def validation_step(self, batch, _):
        self._calculate_loss(batch, mode='val')

    def test_step(self, batch, _):
        self._calculate_loss(batch, mode='test')

    def predict_step(self, batch, _):
        """Return predictions in original target units for one batch."""
        _, traits, image_features, _, _ = batch
        return self.forward(image_features, traits)


main_module = MainPlantModule()

In [None]:
# torch.compile is not used: it did not work on the GTX 1080 this was developed on.
# Same location as Lightning's default logger (saved_models/lightning_logs/version_N), but graph
# logging and the disabled placeholder hp_metric are set through public constructor arguments
# instead of private attributes.
main_logger = TensorBoardLogger(save_dir="saved_models", name="lightning_logs",
                                log_graph=True, default_hp_metric=False)
main_trainer = L.Trainer(default_root_dir="saved_models",
                         accelerator="auto",
                         devices=1,
                         max_epochs=100,
                         log_every_n_steps=5,
                         logger=main_logger,
                         callbacks=[
                             # Keep only the checkpoint with the highest validation R^2.
                             ModelCheckpoint(monitor="val_r2",
                                             save_top_k=1,
                                             mode="max"),
                             LearningRateMonitor("epoch"),
                         ])

main_trainer.fit(main_module, train_loader, val_loader)
# Evaluate the best checkpoint (not the last-epoch weights). test.csv has no labels, so the
# validation loader doubles as the evaluation set.
results = main_trainer.test(main_module, dataloaders=val_loader, ckpt_path="best", verbose=False)
print("Results", results)

In [None]:
# Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH!
%load_ext tensorboard
%tensorboard --logdir /content/saved_models/lightning_logs

In [None]:
# Predict the unlabeled test set with the best checkpoint selected by val_r2; without ckpt_path,
# predict would use the final-epoch weights.
predictions = main_trainer.predict(main_module, dataloaders=test_loader, ckpt_path="best")
concat_predictions = torch.cat(predictions, dim=0).detach().cpu().numpy()
ids = test_dataset.data.iloc[:, 0].values.tolist()
# test_loader is unshuffled with drop_last=False, so prediction rows follow the CSV row order.
assert len(ids) == len(concat_predictions), "prediction count does not match number of test ids"

# NOTE(review): assumed to match the order of train.csv's target columns (164 onward); confirm.
SUBMISSION_COLUMNS = ['X4', 'X11', 'X18', 'X50', 'X26', 'X3112']
df_predictions = pd.DataFrame(concat_predictions, columns=SUBMISSION_COLUMNS)
df_predictions.insert(0, 'id', ids)
df_predictions.to_csv('submission.csv', index=False)

In [None]:
# Upload submission.csv to the competition.
# NOTE(review): replace the placeholder -m message with one identifying the run/checkpoint.
! kaggle competitions submit -c cs-480-2024-spring -f submission.csv -m "submission message"