In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import os
import lightning
import sklearn
import kornia.augmentation as K

In [2]:
# `import sklearn` alone is not guaranteed to load the model_selection submodule
# on every scikit-learn version, so import it explicitly.
import sklearn.model_selection

# Hold out 20% of the labelled rows for validation. random_state pins the split
# so val_loss, early stopping and checkpoint selection are reproducible.
train_df, valid_df = sklearn.model_selection.train_test_split(
    pd.read_csv('train.csv'), train_size=0.8, random_state=42
)
test_df = pd.read_csv('test.csv')

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [35]:
class AgricultureDataset(torch.utils.data.Dataset):
    """Multi-spectral image cubes stored as ``.npy`` files, one per DataFrame row.

    Each sample is loaded from ``os.path.join(img_dir, row.id)`` as an array of
    shape (H, W, D), cropped / zero-padded to ``target_shape`` and returned as a
    float32 tensor of shape (D, H, W), i.e. the bands become the channel axis.

    Args:
        df: DataFrame with an ``id`` column (file name inside ``img_dir``) and,
            optionally, a ``label`` column. Without ``label`` every target is 0.
        img_dir: Directory containing the ``.npy`` files.
        transform: When truthy, apply the built-in kornia augmentation pipeline
            (random horizontal/vertical flips and a small random affine).
        target_transform: Optional callable applied to the float label.
        target_shape: (H, W, D) that every cube is cropped / zero-padded to.
    """

    def __init__(self, df, img_dir, transform=None, target_transform=None,
                 target_shape=(128, 128, 125)):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.target_shape = tuple(target_shape)
        self.has_labels = 'label' in df.columns
        # Default augmentations; only applied when ``transform`` is truthy.
        self.transformations = torch.nn.Sequential(
            K.RandomHorizontalFlip(p=0.3),
            K.RandomVerticalFlip(p=0.3),
            K.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05), p=0.5),
        )

    def __len__(self):
        return len(self.df)

    def _load_padded(self, file_name):
        """Load one (H, W, D) cube and crop / zero-pad it to ``self.target_shape``."""
        img = np.load(os.path.join(self.img_dir, file_name))
        H, W, D = img.shape
        TH, TW, TD = self.target_shape
        # Copy the overlapping region; anything beyond the source stays zero,
        # anything beyond the target is cropped away.
        copy_H, copy_W, copy_D = min(H, TH), min(W, TW), min(D, TD)
        padded = np.zeros(self.target_shape, dtype=np.float32)
        padded[:copy_H, :copy_W, :copy_D] = img[:copy_H, :copy_W, :copy_D]
        return padded

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``label`` is 0 when ``df`` has no labels."""
        row = self.df.iloc[idx]
        # (H, W, D) -> (D, H, W): channels first.
        multi_spectral_image = torch.from_numpy(self._load_padded(row.id)).permute(2, 0, 1)

        if self.transform:
            # The augmentation pipeline returns a batched tensor; drop the
            # leading batch axis again (mirrors the original squeeze).
            multi_spectral_image = self.transformations(multi_spectral_image).squeeze(dim=0)

        if not self.has_labels:
            return multi_spectral_image, 0

        label = float(row.label)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return multi_spectral_image, label


In [36]:
# Every split reads its .npy cubes from the same directory; only the training
# split gets the random augmentations.
IMG_DIR = 'ot/ot'
train_dataset = AgricultureDataset(train_df, IMG_DIR, transform=True)
valid_dataset = AgricultureDataset(valid_df, IMG_DIR)
test_dataset = AgricultureDataset(test_df, IMG_DIR)

In [37]:
# Sanity check: a training sample is channels-first (bands, H, W) after padding.
# Asserting (rather than eyeballing) makes Restart & Run All catch regressions.
sample_image, _ = train_dataset[0]
assert sample_image.shape == torch.Size([125, 128, 128]), sample_image.shape
sample_image.shape

torch.Size([125, 128, 128])

In [43]:
BATCH_SIZE = 16
num_workers = 0
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
# Shuffling the validation set changes nothing about val_loss (Lightning warns
# about shuffled val dataloaders), so keep its order fixed.
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)
# Must stay unshuffled: predictions are written back to test_df row-by-row.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)

In [44]:
import torch
import torch.nn as nn

class MyModel(nn.Module):
    """Residual 3-D CNN regressor over 125-band multi-spectral images.

    A 4-D input (B, 125, H, W) is treated as a depth-1 volume (B, 125, 1, H, W);
    a 5-D input is used as-is. The main branch (conv 125->64, BN, ReLU, conv
    64->128, BN) is added to a 1x1x1 projection of the input. The sum then goes
    through ReLU, global average pooling and a linear layer, giving (B, 1).

    NOTE(review): with a depth-1 volume and padding=1, only the centre depth
    slice of each 3x3x3 kernel sees non-zero input, so both convs behave like
    3x3 2-D convolutions. Switching to Conv2d would change the parameter shapes
    and break existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Main branch, stage 1: conv -> BN -> ReLU.
        self.conv_3_3_1 = nn.Conv3d(in_channels=125, out_channels=64, kernel_size=3, padding=1)
        self.batch_norm_1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU()
        # Main branch, stage 2: conv -> BN (activation applied after the residual add).
        self.conv_3_3_2 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.batch_norm_2 = nn.BatchNorm3d(128)
        # Shortcut: 1x1x1 projection so the input matches the 128-channel main branch.
        self.residual_conv = nn.Conv3d(in_channels=125, out_channels=128, kernel_size=1)
        # Head.
        self.final_relu = nn.ReLU()
        self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.output = nn.Linear(128, 1)

    def forward(self, x):
        """Map a (B, C, H, W) or (B, C, D, H, W) batch to predictions of shape (B, 1)."""
        if x.dim() == 4:
            # Insert a singleton depth axis: (B, C, H, W) -> (B, C, 1, H, W).
            x = x.unsqueeze(2)

        shortcut = self.residual_conv(x)

        main = self.conv_3_3_1(x)
        main = self.relu(self.batch_norm_1(main))
        main = self.batch_norm_2(self.conv_3_3_2(main))

        merged = self.final_relu(shortcut + main)
        pooled = self.global_pool(merged)
        return self.output(pooled.view(pooled.size(0), -1))


In [56]:
class LitMyModel(lightning.LightningModule):
    """LightningModule that trains ``model`` as a scalar regressor with MSE loss.

    Args:
        model: ``nn.Module`` mapping an image batch to predictions of shape
            (B, 1); the trailing axis is squeezed before computing the loss.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-3).
    """

    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.model = model.float()
        self.lr = lr

    def _mse_step(self, batch):
        """Return the MSE between squeezed predictions and targets of an (x, y) batch."""
        x, y = batch
        preds = self.model(x).squeeze(dim=1)
        # Targets may be collated as float64; cast to float32 to match preds.
        return torch.nn.functional.mse_loss(preds, y.float())

    def training_step(self, batch, batch_idx):
        loss = self._mse_step(batch)
        self.log_dict({'loss': loss})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def validation_step(self, batch, batch_idx):
        loss = self._mse_step(batch)
        # Log the tensor itself: calling .item() here would force a
        # device-to-host sync on every validation batch.
        self.log_dict({'val_loss': loss})
        return {'loss': loss}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return squeezed predictions for a batch, ignoring its targets.

        Runs under ``torch.no_grad`` and moves inputs to this module's device, so
        it can also be called directly (outside ``Trainer.predict``). It does not
        change train/eval mode; call ``.eval()`` first for inference.
        """
        x, _ = batch
        with torch.no_grad():
            return self.model(x.to(self.device)).squeeze(dim=1)

        

In [57]:
# Seed Python, NumPy and PyTorch (and dataloader workers) so weight init,
# shuffling and augmentations are reproducible across runs.
lightning.pytorch.seed_everything(42, workers=True)

model = MyModel()
trainableModel = LitMyModel(model)
callbacks = [
    # Stop once val_loss stops improving (patience defaults to 3 checks).
    lightning.pytorch.callbacks.EarlyStopping(monitor='val_loss', mode='min'),
    # Keep the checkpoint with the lowest val_loss.
    lightning.pytorch.callbacks.ModelCheckpoint(monitor='val_loss', mode='min'),
]

In [None]:
# Train on the chosen accelerator; val_loss from the validation split drives
# early stopping and checkpoint selection through `callbacks`.
trainer = lightning.Trainer(
    accelerator=device,
    callbacks=callbacks,
)
trainer.fit(
    trainableModel,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

In [59]:
# Prefer the best checkpoint recorded by this session's ModelCheckpoint; fall
# back to the earlier run's checkpoint when training was not executed in this
# kernel (best_model_path is empty until a checkpoint has been saved).
FALLBACK_CKPT_PATH = 'lightning_logs/version_5/checkpoints/epoch=6-step=763.ckpt'
checkpoint_cb = next(
    (cb for cb in callbacks if isinstance(cb, lightning.pytorch.callbacks.ModelCheckpoint)),
    None,
)
best_ckpt_path = checkpoint_cb.best_model_path if checkpoint_cb is not None else ''
# LitMyModel does not save hyperparameters, so the wrapped module must be supplied.
trainedModel = LitMyModel.load_from_checkpoint(best_ckpt_path or FALLBACK_CKPT_PATH, model=MyModel())

In [60]:
# predict_step is called manually (not via Trainer.predict), so switch to eval
# mode here. Otherwise BatchNorm normalizes with per-batch statistics and keeps
# updating its running stats during inference.
trainedModel.eval()

all_preds = []
for batch_idx, batch in enumerate(test_dataloader):
    preds = trainedModel.predict_step(batch, batch_idx)
    # Round to the nearest integer; int() would truncate toward zero and bias
    # every regression output downward.
    all_preds.extend(preds.round().long().tolist())

In [61]:
# Attach the predictions as the label column. They line up with test_df row by
# row because the test loader does not shuffle.
test_df['label'] = pd.Series(all_preds, index=test_df.index)

In [62]:
# Write the submission without the DataFrame's positional index column.
SUBMISSION_PATH = 'sub.csv'
test_df.to_csv(SUBMISSION_PATH, index=False)