<a href="https://colab.research.google.com/github/amedyukhina/biomassters/blob/main/biomassters_all_times_unet_3D_all_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%writefile requirements.txt

pandas==1.3.5
scikit-learn==1.0.2
tqdm==4.64.0
numpy==1.21.6
torch
torchvision
scikit-image
matplotlib
pytorch_lightning
urllib3==1.25.4
monai==0.9.1
wandb
boto3==1.26.16
rasterio==1.2.0

In [None]:
!pip install -r requirements.txt

In [None]:
from google.colab import drive
import os
import pandas as pd
from tqdm import tqdm
import numpy as np
from skimage import io

from cachetools import cached, TTLCache, MRUCache

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose, Normalize
import torch.nn.functional as F
from torch import nn
from torchvision.io import read_image
from torchvision import transforms
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import pytorch_lightning as pl
import warnings
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
import shutil
from scipy import ndimage
from monai.networks.layers import Norm
from monai.networks.nets import UNet, DynUNet
from monai.inferers import sliding_window_inference
import wandb

In [None]:
drive.mount('/content/gdrive')

In [None]:
# Read the Weights & Biases API key from a file on the mounted Drive so the
# secret never appears in the notebook itself.
with open('gdrive/MyDrive/Personal/wandb_apikey') as f:
    key = f.read()

# Strip every kind of surrounding whitespace (spaces, "\r\n" from Windows
# editors, trailing blank lines); a stray "\r" would make the key invalid.
os.environ['WANDB_API_KEY'] = key.strip()

### Prepare list of chip IDs

In [None]:
# Notebook configuration: paths are relative to the Colab working directory.
# CSV with one row per image file; must contain 'split' and 'chip_id' columns.
feature_path = 'gdrive/MyDrive/biomassters/info/features_metadata.csv'
# Folder holding the '<mode>_mean_std.npy' normalisation statistics.
data_path = 'gdrive/MyDrive/biomassters/data/'
# Prefixes passed to the datasets and joined with the file name when fetching chips.
train_img_dir ='train_features'
test_img_dir = 'test_features'
label_dir = 'train_agbm'
# Root folder for checkpoints and predictions (one sub-folder per wandb run name).
model_checkpoint_path = 'gdrive/MyDrive/biomassters/models/'
# Satellite tag used in image file names and in the statistics file name.
MODE = 'S2'
# Number of chips held out for validation.
nval = 100
# Only referenced by the commented-out RandomCrop / reshape lines below.
PATCH_SIZE = 32

In [None]:
# Load the metadata table once and carve it into the prediction (test) part
# and the labelled (train) part.
metadata = pd.read_csv(feature_path)

df_pred = metadata[metadata['split'] == 'test'].reset_index(drop=True)
pred_ids = np.unique(df_pred['chip_id'])

df = metadata[metadata['split'] == 'train'].reset_index(drop=True)
all_ids = np.unique(df['chip_id'])

# Seeded in-place shuffle, then hold out the last `nval` chips for validation.
np.random.seed(42)
np.random.shuffle(all_ids)
train_ids, val_ids = all_ids[:-nval], all_ids[-nval:]

len(all_ids), len(train_ids), len(val_ids)

### Set up data loading

In [None]:
# Cache with room for 1000 entries. It is currently unused: the `@cached(cache)`
# decorators on the image-loading functions below are commented out.
cache = MRUCache(maxsize=1000) 

In [None]:
import boto3
from botocore import UNSIGNED
from botocore.config import Config
import rasterio

# Our rasters contain no geolocation info, so silence this warning from rasterio
warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

# Bucket that get_image_from_aws() reads from.
BUCKET_NAME = 'drivendata-competition-biomassters-public-us'
# NOTE(review): presumably tells GDAL/rasterio to skip request signing; the
# downloads below go through boto3, so confirm this is still needed.
os.environ["AWS_NO_SIGN_REQUEST"] = 'YES'
# Unsigned client: requests are sent without AWS credentials.
s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))

# @cached(cache)
def get_image_from_aws(fn):
    """Download one raster from the S3 bucket and return its pixel data.

    Parameters
    ----------
    fn : str
        Object key inside ``BUCKET_NAME``.

    Returns
    -------
    numpy.ndarray
        The array returned by rasterio's ``read()`` for the file. If the
        object cannot be fetched or decoded, an all-zero placeholder of shape
        ``(len(MEANS), 256, 256)`` is returned instead, so callers always get
        an array back.
    """
    try:
        obj = s3.get_object(Bucket=BUCKET_NAME, Key=fn)
        with rasterio.open(obj['Body']) as src:
            img = src.read()
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still stop a long-running download loop.
    except Exception:
        img = np.zeros((len(MEANS), 256, 256))
    return img
     

In [None]:
# @cached(cache)
def get_image(fn):
    """Read an image from local disk, falling back to zeros if it is missing.

    When the file exists it is read with ``io.imread``; if its trailing axis is
    shorter than 20 that axis is moved to the front (the code treats it as the
    channel axis). When the file does not exist, an all-zero array of shape
    ``(len(MEANS), 256, 256)`` is returned.
    """
    # Guard clause: missing file -> zero placeholder.
    if not os.path.exists(fn):
        return np.zeros((len(MEANS), 256, 256))

    pixels = io.imread(fn)
    looks_channels_last = pixels.shape[-1] < 20
    return np.moveaxis(pixels, -1, 0) if looks_channels_last else pixels

In [None]:
# Per-channel normalisation statistics for the selected MODE. The file holds
# two rows (means, stds); each is reshaped to (channels, 1, 1) so it
# broadcasts against (channels, height, width) images.
stats_path = data_path + f'{MODE.lower()}_mean_std.npy'
MEANS, STDS = (row.reshape(-1, 1, 1) for row in np.load(stats_path))

### Define a dataset

In [None]:
class SentinelDataset2(Dataset):
    """Twelve monthly images of one chip plus (optionally) its biomass label.

    Parameters
    ----------
    chip_ids : sequence of str
        Chip identifiers; one dataset item per chip.
    img_dir : str
        Prefix joined with ``<chip>_<MODE>_<month>.tif`` to fetch each image.
    label_dir : str or None
        Prefix joined with ``<chip>_agbm.tif`` to fetch the label. ``None``
        means no labels are loaded and the label outputs are ``None``.
    transform : callable or None
        Applied once to the image bands and label planes concatenated along
        the first axis, so all of them receive the same random geometry.

    Each item is ``(image, label, label_filt)``: ``image`` is a float32 tensor
    of shape (12, channels, H, W) normalised with ``MEANS``/``STDS``; ``label``
    is the raw label (1, H, W); ``label_filt`` is the label after a 3x3 median
    filter (1, H, W).
    """

    def __init__(self, chip_ids, img_dir, label_dir=None,
                 transform=None):
        self.chip_ids = chip_ids
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform

    def __len__(self):
        return len(self.chip_ids)

    def __getitem__(self, idx):
        chip_id = self.chip_ids[idx]

        # Stack the 12 monthly rasters into a (time, channel, H, W) tensor.
        image = torch.stack([
            torch.tensor(
                get_image_from_aws(
                    os.path.join(self.img_dir, rf"{chip_id}_{MODE}_{month:02d}.tif")
                ).astype(np.float32)
            )
            for month in range(12)
        ])
        t, c, h, w = image.shape

        if self.label_dir is not None:
            label = get_image_from_aws(
                os.path.join(self.label_dir, rf"{chip_id}_agbm.tif"))
            # Keep only the first band: when the label cannot be read the
            # loader returns a zero placeholder with len(MEANS) bands, which
            # would otherwise break the slicing below and batch collation.
            label = label[:1]
            label_filt = ndimage.median_filter(label[0], 3)
            label_filt = torch.tensor(label_filt.astype(np.float32)).unsqueeze(0)
            label = torch.tensor(label.astype(np.float32))
        else:
            label = label_filt = None

        if self.transform:
            # Transform image and labels together so they stay aligned; labels
            # are simply left out when the dataset has none.
            planes = [image.reshape(t * c, h, w)]
            if label is not None:
                planes += [label, label_filt]
            stacked = self.transform(torch.cat(planes))
            if label is not None:
                label = stacked[-2:-1]
                label_filt = stacked[-1:]
                stacked = stacked[:-2]
            # Use the post-transform size so size-changing transforms
            # (e.g. a random crop) also work.
            image = stacked.reshape(t, c, *stacked.shape[-2:])

        # Normalise after the transform. Statistics are cast to float32 once,
        # so the result never silently becomes float64, and they broadcast
        # over the time axis.
        means = torch.as_tensor(MEANS, dtype=torch.float32)
        stds = torch.as_tensor(STDS, dtype=torch.float32)
        image = (image - means) / stds

        return image, label, label_filt

In [None]:
# Random geometric augmentation for training. The dataset applies it to the
# image bands and the label planes concatenated into one tensor, so all of
# them receive the same flip / rotation.
train_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        # NOTE(review): per torchvision docs a scalar `degrees` means the angle
        # is drawn from (-180, +180); corners exposed by the rotation are
        # presumably filled with the default value 0 — confirm this is intended.
        transforms.RandomRotation(degrees=180),
        # transforms.RandomCrop(PATCH_SIZE, pad_if_needed=True),
    ])

In [None]:
# Both datasets read from the training images; only the training set is augmented.
train_ds = SentinelDataset2(
    chip_ids=train_ids,
    img_dir=train_img_dir,
    label_dir=label_dir,
    transform=train_transforms,
)
val_ds = SentinelDataset2(chip_ids=val_ids, img_dir=train_img_dir, label_dir=label_dir)

# Settings shared by both loaders; only the shuffling differs.
loader_kwargs = dict(batch_size=4, num_workers=2)
train_dataloader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
valid_dataloader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [None]:
torch.random.manual_seed(42)

In [None]:
%%time
imgs, labels, labels_filt = next(iter(train_dataloader))

In [None]:
imgs.shape, labels.shape

In [None]:
# Show the channels of the first sample in the batch at time index 10,
# one channel per panel on a 2 x 5 grid.
rows, cols = 2, 5
s = 3  # size of one panel in inches
fig, axes = plt.subplots(rows, cols, figsize=(s * cols, s * rows))
time_slice = imgs[0][10]
for ax, channel in zip(axes.flat, time_slice):
    ax.imshow(channel.numpy())

In [None]:
# Show ground truth
plt.imshow(labels_filt[0][0].numpy())

### Define the model and the training pipeline

In [None]:
class Sentinel2Model(pl.LightningModule):
    """Lightning wrapper that trains a regression network with an MSE loss.

    Parameters
    ----------
    model : torch.nn.Module
        Network mapping an image batch to a prediction shaped like the label.
    lr : float, optional
        Initial Adam learning rate (default 0.02, the previously hard-coded value).
    gamma : float, optional
        Per-epoch decay factor of the exponential LR schedule (default 0.9,
        the previously hard-coded value).

    Batches are ``(img, label, label_filt)`` tuples.
    """

    def __init__(self, model, lr=0.02, gamma=0.9):
        super().__init__()
        self.model = model
        self.lr = lr
        self.gamma = gamma

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # Training uses the third batch element, i.e. the median-filtered label.
        img, _, label = batch
        predicted = self.model(img)
        loss = F.mse_loss(predicted, label)
        self.log("train/loss", loss)
        self.log("train/rmse", torch.sqrt(loss))
        return loss

    def validation_step(self, batch, batch_idx):
        img, label, label_filt = batch
        predicted = self.model(img)
        # Main validation metrics are computed against the filtered label ...
        loss = F.mse_loss(predicted, label_filt)
        self.log("valid_loss", loss)
        self.log("valid_rmse", torch.sqrt(loss))
        # ... and the RMSE against the unfiltered label is logged for reference.
        self.log("valid_rmse_nonfilt", torch.sqrt(F.mse_loss(predicted, label)))
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma)
        return [optimizer], [scheduler]

In [None]:
# Per-level kernel sizes and strides, ordered (time, height, width).
kernels = [[1, 5, 5], [3, 3, 3], [1, 3, 3], [3, 3, 3], [1, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [1, 2, 2], [2, 2, 2], [1, 2, 2]]

class Unet3D(nn.Module):
    """3D U-Net over (time, H, W) followed by a 1x1 convolution across time.

    Parameters
    ----------
    kernels, strides : list of list of int
        Passed to MONAI's ``DynUNet`` as ``kernel_size`` / ``strides``;
        ``strides[1:]`` is also used as the upsampling kernel sizes.
    n_timepoints : int, optional
        Number of time steps in the input; sets the input channels of the
        final convolution that collapses the time axis. Defaults to 12, the
        previously hard-coded value.
    """

    def __init__(self, kernels, strides, n_timepoints=12):
        super().__init__()
        self.unet = DynUNet(
            spatial_dims=3,
            in_channels=len(MEANS),
            out_channels=1,
            kernel_size=kernels,
            strides=strides,
            upsample_kernel_size=strides[1:],
            norm_name="batch",
            deep_supervision=False,
            deep_supr_num=3,
        )
        # Treats the time axis as channels and mixes it down to a single map.
        self.conv = nn.Conv3d(n_timepoints, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, img):
        # Swap axes 1 and 2 so the U-Net sees image channels as its channel
        # axis and time as the first spatial axis.
        out = self.unet(img.transpose(1, 2))
        # Swap back so time becomes the channel axis of the final conv, then
        # drop the resulting singleton axis 1.
        return self.conv(out.transpose(1, 2)).squeeze(1)


base_model = Unet3D(kernels, strides)

In [None]:
s2_model = Sentinel2Model(base_model)

In [None]:
wandb_logger = WandbLogger(project='BioMassters_all_timepoints')

In [None]:
# Keep the checkpoint with the lowest validation RMSE in a folder named after
# the wandb run. The run name is taken from the logger's `experiment` (which
# starts the run if needed) rather than from the global `wandb.run`, which is
# still None when the logger initialises wandb lazily.
checkpoint_callback = ModelCheckpoint(
     monitor='valid_rmse',
     mode='min',
     dirpath=os.path.join(model_checkpoint_path, wandb_logger.experiment.name),
     filename='{epoch:02d}-{valid_rmse:.2f}')
# Log the learning rate at every optimisation step.
lr_monitor = LearningRateMonitor(logging_interval='step')

In [None]:
# Initialize a trainer: GPU training for at most 30 epochs, logging to wandb
# every 5 steps, with checkpointing and learning-rate monitoring callbacks.
trainer = Trainer(
    accelerator="gpu",
    max_epochs=30,
    logger=[wandb_logger],
    callbacks=[checkpoint_callback, lr_monitor],
    log_every_n_steps=5
)

### Train

In [None]:
%%time
# Train the model ⚡
torch.random.manual_seed(42)
trainer.fit(s2_model, train_dataloaders=train_dataloader, 
            val_dataloaders=valid_dataloader)

### Show example predictions

In [None]:
model_name = wandb.run.name
# To load a checkpoint from an earlier run, set its name explicitly instead:
# model_name = 'playful-field-16'
ckpt_dir = os.path.join(model_checkpoint_path, model_name)

# Only consider real checkpoint files: the run folder may also contain the
# 'predicted' sub-folder and 'submission.zip' written further down, and
# os.listdir() returns entries in arbitrary order.
ckpt_files = sorted(name for name in os.listdir(ckpt_dir) if name.endswith('.ckpt'))
if not ckpt_files:
    raise FileNotFoundError(f"No .ckpt file found in {ckpt_dir}")
fn = ckpt_files[0]

# Load on CPU first so the checkpoint can be read regardless of the device it
# was saved from; the model is moved to the GPU afterwards.
checkpoint = torch.load(os.path.join(ckpt_dir, fn), map_location='cpu')
s2_model.load_state_dict(checkpoint['state_dict'])
s2_model.eval().cuda();

In [None]:
# Predict one validation batch. Gradients are not needed for inference, so
# disable autograd to avoid keeping the computation graph in GPU memory.
imgs, labels, _ = next(iter(valid_dataloader))
with torch.no_grad():
    pred = s2_model(imgs.cuda())

In [None]:
# Side by side: median-filtered ground truth, prediction, and their difference
# for the first sample of the batch.
s = 7  # size of one panel in inches
target = ndimage.median_filter(labels[0].numpy()[0], 3)
predicted = pred[0].cpu().detach().numpy()[0]

fig, axes = plt.subplots(1, 3, figsize=(s*3, s))
axes[0].imshow(target)
axes[1].imshow(predicted)
axes[2].imshow(target - predicted)

### Prediction

In [None]:
%%time
# Predict every test chip and save the result as '<chip_id>_agbm.tif'.
output_dir = os.path.join(model_checkpoint_path, model_name, 'predicted')
os.makedirs(output_dir, exist_ok=True)

# Build the normalisation tensors once instead of once per month and chip.
means = torch.as_tensor(MEANS, dtype=torch.float32)
stds = torch.as_tensor(STDS, dtype=torch.float32)

s2_model.eval()
# Inference only: without no_grad() every forward pass would build an
# autograd graph and waste GPU memory.
with torch.no_grad():
    for chip_id in tqdm(pred_ids):
        # (time, channel, H, W) stack of the 12 monthly rasters.
        image = torch.stack([
            torch.tensor(
                get_image_from_aws(
                    os.path.join(test_img_dir, rf"{chip_id}_{MODE}_{month:02d}.tif")
                ).astype(np.float32)
            )
            for month in range(12)
        ])
        image = (image - means) / stds

        pred = s2_model(image.unsqueeze(0).cuda())
        img = pred.squeeze().cpu().numpy()
        io.imsave(f"{output_dir}/{chip_id}_agbm.tif", img)

In [None]:
# Zip the contents of the prediction folder. The archive base name points one
# level up from `output_dir`, so 'submission.zip' is written next to the
# 'predicted' folder rather than inside it.
fn = os.path.join(output_dir, '../submission')
shutil.make_archive(fn, 'zip', output_dir)

In [None]:
from google.colab import files
files.download(fn + '.zip')