In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Data

In [2]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor

import os
import random
import datetime
from copy import deepcopy

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from utils.dataset import CutOrPad, get_rgb

In [3]:
PASTIS24 = './data/PASTIS24/'
PASTIS9 = './data/PASTIS9/'

PATH = PASTIS24

In [4]:
files = os.listdir(PATH)
file = random.choice(files)

In [5]:
data = pd.read_pickle(PATH + file)

print(data.keys())
print('Image: ', data['img'].shape)
print('Labels: ', data['labels'].shape, data['labels'])
print('DOY: ', data['doy'].shape, data['doy'])

dict_keys(['img', 'labels', 'doy'])
Image:  (43, 10, 24, 24)
Labels:  (24, 24) [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0

In [6]:
class PASTIS(Dataset):
    def __init__(self, pastis_path):
        self.pastis_path = pastis_path

        self.file_names = os.listdir(self.pastis_path)[:1000]
        self.to_cutorpad = CutOrPad()


    def __len__(self):
        return len(self.file_names)


    def add_date_channel(self, img, doy):
        img = torch.cat((img, doy), dim=1)
        return img


    def normalize(self, img):
        C = img.shape[1]
        mean = img.mean(dim=(0, 2, 3)).to(torch.float32).reshape(1, C, 1, 1)
        std = img.std(dim=(0, 2, 3)).to(torch.float32).reshape(1, C, 1, 1)

        img = (img - mean) / std

        return img


    def __getitem__(self, idx):
        data = pd.read_pickle(os.path.join(self.pastis_path, self.file_names[idx]))

        data['img'] = data['img'].astype('float32')
        data['img'] = torch.tensor(data['img'])
        data['img'] = self.normalize(data['img'])
        T, C, H, W = data['img'].shape

        data['labels'] = data['labels'].astype('float32')
        data['labels'] = torch.tensor(data['labels'])
        # data['labels'] = F.one_hot(data['labels'].long(), num_classes=20)

        data['doy'] = data['doy'].astype('float32')
        data['doy'] = torch.tensor(data['doy'])
        data['doy'] = data['doy'].unsqueeze(1).unsqueeze(1).unsqueeze(1)
        data['doy'] = data['doy'].repeat(1, 1, H, W)

        data['img'] = self.add_date_channel(data['img'], data['doy'])

        del data['doy']
        data = self.to_cutorpad(data)
        del data['seq_lengths']


        return data['img'], data['labels']

In [7]:
data = PASTIS(PATH)
data.__len__()

1000

In [8]:
dataset = DataLoader(data, batch_size=4, shuffle=True)

### CLS

In [None]:
img, label = next(iter(dataset))
plt.imshow(get_rgb(img[0][:,:-1,:,:].numpy()))
print(label[0].unique())

### Seg

In [None]:
img, label = next(iter(dataset))

fix, axes = plt.subplots(1,2, figsize=(10,10))
axes[0].imshow(get_rgb(img[0][:,:-1,:,:].numpy()))
axes[1].imshow(label[0].numpy())

axes[0].set_title('img')
axes[1].set_title(f'{label.unique()}')

# Model

## Segmentation

### Temporal-Spatial Vision Transformer

In [9]:
class Segmentation(nn.Module):
    def __init__(self, img_height=24, img_width=24, in_channel=10,
                       patch_size=3, embed_dim=128, max_time=60,
                       num_classes=20, num_head=4, dim_feedforward=2048,
                       num_layers=4
                ):
        super().__init__()
        
        self.H = img_height
        self.W = img_width
        self.P = patch_size
        self.C = in_channel
        self.d = embed_dim
        self.T = max_time
        self.K = num_classes

        self.d_model = self.d
        self.num_head = num_head
        self.dim_feedforward = dim_feedforward
        self.num_layers = num_layers

        self.N = int(self.H * self.W // self.P**2)
        # self.n = int(self.N**0.5)
        self.nh = int(self.H / self.P)
        self.nw = int(self.W / self.P)

        self.encoderLayer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.num_head, dim_feedforward=self.dim_feedforward)
        self.encoder = nn.TransformerEncoder(self.encoderLayer, num_layers=self.num_layers)

        self.projection = nn.Conv3d(self.C, self.d, kernel_size=(1, self.P, self.P), stride=(1, self.P, self.P))
        '''
        def __init__():
            self.linear = nn.Linear(self.C*self.P**2, self.d)
        def forward():
            x = x.view(B, T, H // P, W // P, C*P**2)
            x = self.linear(x)
        '''

        self.temporal_emb = nn.Linear(366, self.d)
        self.temporal_cls_token = nn.Parameter(torch.randn(1, self.K, self.d)) # (1, K, d)
        self.temporal_transformer = self.encoder

        self.spatial_emb = nn.Parameter(torch.randn(1, self.N, self.d)) # (1, N, d)
        # self.spatial_cls_token = nn.Parameter(torch.randn(1, self.K, self.d)) # (1, K, d)
        self.spatial_transformer = self.encoder



        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.d),
            nn.Linear(self.d, self.P**2)
            )



    def forward(self, x):

        '''
        Tekenization
        '''
        # remove the timestamps (last channel) from the input
        x_sits = x[:, :, :-1]
        B, T, C, H, W = x_sits.shape # (B, T, C, H, W)
        
        x_sits = x_sits.reshape(B, C, T, H, W) # (B, C, T, H, W)
        x_sits = self.projection(x_sits) # (B, d, T, nw, nh)
        x_sits = x_sits.view(B, self.d, T, self.nh*self.nw) # (B, d, T, N)

        # Spatial Encoding (Positional Embeddings)
        # we dont add pos embedding here, cuz we need the pure data for the temporal encoder
        # x_sits = x_sits + self.pos_emb # (B, d, T, N) 

        x_sits = x_sits.permute(0,3,2,1) # (B, N, T, d)

        '''
        Temporal Encoding
        '''
        # in the last channel lies the timestamp
        xt = x[:, :, -1, 0, 0] # (B, T, C, H, W)
        # convert to one-hot
        # xt = (xt * 365.0001).to(torch.int64)
        xt = F.one_hot(xt.to(torch.int64), num_classes=366).to(torch.float32) # (B, T, 366)
        Pt = self.temporal_emb(xt) # (B, T, d)

        '''
        Temporal Encoder: cat(Z+Pt)
        '''
        x = x_sits + Pt.unsqueeze(1) # (B, N, T, d)
        temporal_cls_token = self.temporal_cls_token # (1, 1, K, d)
        temporal_cls_token = temporal_cls_token.repeat(B, self.N, 1, 1) # (B, N, K, d)
        x = torch.cat([temporal_cls_token, x], dim=2) # (B, N, K+T, d)
        x = x.view(B*self.N, self.K + T, self.d)
        x = self.temporal_transformer(x) # (B*N, K+T, d)
        x = x.view(B, self.N, self.K + T, self.d) # (B, N, K+T, d)
        x = x[:,:,:self.K] # (B, N, K, d)
        x = x.reshape(B, self.K, self.N, self.d) # (B, K, N, d)

        '''
        Spatial Encoding
        '''
        Ps = self.spatial_emb # (1, N, d)
        Ps = Ps.unsqueeze(1) # (1, 1, N, d)
        x = x + Ps # (B, K, N, d)

        # spatial_cls_token = self.spatial_cls_token # (1, K, d)
        # spatial_cls_token = spatial_cls_token.unsqueeze(2) # (1, K, 1, d)
        # spatial_cls_token = spatial_cls_token.repeat(B, 1, 1, 1) # (B, K, 1, d)
        # x = torch.cat([spatial_cls_token, x], dim=2) # (B, K, 1+N, d)

        x = x.view(B*(self.N), self.K, self.d) # (B*(N), K, d)
        x = self.spatial_transformer(x) # (B*(N), K, d)
        x = x.view(B, self.N, self.K, self.d) # (B, (N), K, d)

        # classes = x[:,0,:,:] # (B, K, d)

        # x = x[:,1:,:,:] # (B, N, K, d)

        x = self.mlp_head(x) # (B, N, K, P*P)

        x = x.view(B, self.N, self.K, self.P, self.P) # (B, N, K, P, P)

        x = x.view(B, self.K, self.N, self.P*self.P) # (B, K, N, P*P)

        x = x.view(B, self.K, H, W) # (B, K, H, W)


        return x

## Classification

### Temporal-Spatial Vision Transformer

In [None]:
class Classification(nn.Module):
    def __init__(self, img_height=9, img_width=9, in_channel=10,
                       patch_size=3, embed_dim=512, max_time=60,
                       num_classes=20, num_head=4, dim_feedforward=2048,
                       num_layers=4
                ):
        super().__init__()
        
        self.H = img_height
        self.W = img_width
        self.P = patch_size
        self.C = in_channel
        self.d = embed_dim
        self.T = max_time
        self.K = num_classes

        self.d_model = self.d
        self.num_head = num_head
        self.dim_feedforward = dim_feedforward
        self.num_layers = num_layers

        self.N = int(self.H * self.W // self.P**2)
        # self.n = int(self.N**0.5)
        self.nh = int(self.H / self.P)
        self.nw = int(self.W / self.P)




        # Parameters
        self.encoderLayer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.num_head, dim_feedforward=self.dim_feedforward)
        self.encoder = nn.TransformerEncoder(self.encoderLayer, num_layers=self.num_layers)

        self.projection = nn.Conv3d(self.C, self.d, kernel_size=(1, self.P, self.P), stride=(1, self.P, self.P))


        self.temporal_emb = nn.Linear(366, self.d)
        self.temporal_cls_token = nn.Parameter(torch.randn(1, self.K, self.d)) # (1, K, d)
        self.temporal_transformer = self.encoder


        self.spatial_emb = nn.Parameter(torch.randn(1, self.N, self.d)) # (1, N, d)
        # self.spatial_cls_token = nn.Parameter(torch.randn(1, self.K, self.d)) # (1, K, d)
        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, self.d)) # (1, 1, d)
        self.spatial_transformer = self.encoder
        self.mlp_head = nn.Sequential(nn.LayerNorm(self.d), nn.Linear(self.d, 1))


    def forward(self, x):

        '''
        Tekenization
        '''
        # remove the timestamps (last channel) from the input
        x_sits = x[:, :, :-1]
        B, T, C, H, W = x_sits.shape # (B, T, C, H, W)
        
        x_sits = x_sits.reshape(B, C, T, H, W) # (B, C, T, H, W)
        x_sits = self.projection(x_sits) # (B, d, T, nw, nh)
        x_sits = x_sits.view(B, self.d, T, self.nh*self.nw) # (B, d, T, N)

        # Spatial Encoding (Positional Embeddings)
        # we dont add pos embedding here, cuz we need the pure data for the temporal encoder
        # x_sits = x_sits + self.pos_emb # (B, d, T, N) 

        x_sits = x_sits.permute(0,3,2,1) # (B, N, T, d)



        '''
        Temporal Encoding
        '''
        # in the last channel lies the timestamp
        xt = x[:, :, -1, 0, 0] # (B, T, C, H, W)
        # convert to one-hot
        xt = F.one_hot(xt.to(torch.int64), num_classes=366).to(torch.float32) # (B, T, 366)
        Pt = self.temporal_emb(xt) # (B, T, d)


        '''
        Temporal Encoder: cat(Z+Pt)
        '''
        x = x_sits + Pt.unsqueeze(1) # (B, N, T, d)

        # CLS Token
        temporal_cls_token = self.temporal_cls_token # (1, K, d)
        temporal_cls_token = temporal_cls_token.repeat(B, self.N, 1, 1) # (B, N, K, d)
        x = torch.cat([temporal_cls_token, x], dim=2) # (B, N, K+T, d)
        x = x.view(B*self.N, self.K + T, self.d) # (B*N, K+T, d)

        x = self.temporal_transformer(x) # (B*N, K+T, d)
        x = x.view(B, self.N, self.K + T, self.d) # (B, N, K+T, d)
        x = x[:,:,:self.K,:] # (B, N, K, d)
        x = x.permute(0, 2, 1, 3) # (B, K, N, d)



        '''
        Spatial Encoding
        '''
        Ps = self.spatial_emb # (1, N, d)
        Ps = Ps.unsqueeze(1) # (1, 1, N, d)
        x = x + Ps # (B, K, N, d)
        
        # CLS Token
        spatial_cls_token = self.spatial_cls_token # (1, 1, d)
        spatial_cls_token = spatial_cls_token.repeat(B, self.K, 1) # (B, K, d)
        spatial_cls_token = spatial_cls_token.unsqueeze(2) # (B, K, 1, d)

        x = torch.cat([spatial_cls_token, x], dim=2) # (B, K, N + 1, d)
    
        x = x.view(B*(self.N+1), self.K, self.d) # (B*(N+1), K, d)
        x = self.spatial_transformer(x) # (B*(N+1), K, d)
        x = x.view(B, self.N+1, self.K, self.d) # (B, (N+1), K, d)
        classes = x[:,0,:,:] # (B, K, d)
        # print(classes)
        x = self.mlp_head(classes) # (B, K, 1)
        # print(x, x.shape)
        x = x.reshape(B, self.K) # (B, K)
        
        return x

### SVM

In [None]:
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

X = [x[0,:-1,:,:].reshape(10, 81) for x, y in data]
Y = [y for x, y in data]

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

clf = svm.SVC(kernel='linear')
clf.fit(X_train, y_train)


# Training

## Loss Function

### Classification

In [None]:
class FocalLoss(nn.Module):
    """
    Credits to  github.com/clcarwin/focal_loss_pytorch
    """
    def __init__(self, gamma=8, alpha=torch.ones(20), reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.reduction = reduction
        
    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.reduction is None:
            return loss
        elif self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            raise ValueError(
                "FocalLoss: reduction parameter not in list of acceptable values [\"mean\", \"sum\", None]")

### Segmentation

In [10]:
class MaskedCrossEntropyLoss(torch.nn.Module):
    def __init__(self, mean=True):
        super(MaskedCrossEntropyLoss, self).__init__()
        self.mean = mean
    
    def forward(self, logits, ground_truth):
        if type(ground_truth) == torch.Tensor:
            target = ground_truth
            mask = None
        elif len(ground_truth) == 1:
            target = ground_truth[0]
            mask = None
        elif len(ground_truth) == 2:
            target, mask = ground_truth
        else:
            raise ValueError("ground_truth parameter for MaskedCrossEntropyLoss is either (target, mask) or (target)")
        
        if mask is not None:
            mask_flat = mask.reshape(-1, 1)  # (N*H*W x 1)
            nclasses = logits.shape[-1]
            logits_flat = logits.reshape(-1, logits.size(-1))  # (N*H*W x Nclasses)
            masked_logits_flat = logits_flat[mask_flat.repeat(1, nclasses)].view(-1, nclasses)
            target_flat = target.reshape(-1, 1)  # (N*H*W x 1)
            masked_target_flat = target_flat[mask_flat].unsqueeze(dim=-1).to(torch.int64)
        else:
            masked_logits_flat = logits.reshape(-1, logits.size(-1))  # (N*H*W x Nclasses)
            masked_target_flat = target.reshape(-1, 1).to(torch.int64)  # (N*H*W x 1)
        masked_log_probs_flat = torch.nn.functional.log_softmax(masked_logits_flat, dim=1)  # (N*H*W x Nclasses)
        masked_losses_flat = -torch.gather(masked_log_probs_flat, dim=1, index=masked_target_flat)  # (N*H*W x 1)
        if self.mean:
            return masked_losses_flat.mean()
        return masked_losses_flat

## Train

In [11]:
from torch import optim
import time

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
model = Segmentation(img_width=24, img_height=24, in_channel=10, patch_size=3, embed_dim=128, max_time=60)
model.to(device)

num_params = sum([p.numel() for p in model.parameters() if p.requires_grad == True])
print('Number of Parameters: ', num_params)

# Loss
criterion = MaskedCrossEntropyLoss()
# criterion = nn.CrossEntropyLoss()
# criterion = FocalLoss()
# criterion = nn.MSELoss()

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=0.1)

epoch = 10000
model.train()

for i in range(epoch):
  epoch_loss = 0

  t1 = time.time()
  for img, label in dataset:
    img = img.to(device)
    label = label.to(device).float()

    optimizer.zero_grad()
    output = model(img).permute(0, 2, 3, 1)

    # print(torch.argmax(output), torch.argmax(label))
    # print(output.shape, label.shape)
    # print(output, label)

    loss = criterion(output, label)
    epoch_loss += loss

    loss.backward()
    optimizer.step()


  if i % 10 == 0:
    torch.save({
              'epoch': i,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': loss,
              }, f'./weights/epoch_{i}.pt')
  t2 = time.time()
  print('Epoch: ', i, 'Loss: ', epoch_loss)

Number of Parameters:  3035913
Epoch:  0 Loss:  tensor(749.7170, device='cuda:0', grad_fn=<AddBackward0>)
Epoch:  1 Loss:  tensor(748.9868, device='cuda:0', grad_fn=<AddBackward0>)
Epoch:  2 Loss:  tensor(748.9849, device='cuda:0', grad_fn=<AddBackward0>)


KeyboardInterrupt: 

# Evaluation

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Classification(img_width=9, img_height=9, in_channel=10, patch_size=3, embed_dim=128, max_time=60)
model.to(device)

In [None]:
checkpoint = torch.load('weights/epoch_99.pt')
model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [None]:
loss

# Playground