In [1]:
import os
# os.chdir('/Users/aradinka/Documents/GitHub/koltiva/SSLTransformerRS')

from tqdm import tqdm
import torch
import torchvision.models as models

from dfc_dataset import DFCDataset
from metrics import ClasswiseAccuracy


class DoubleResNetSimCLRDownstream(torch.nn.Module):
    """concatenate outputs from two backbones and add one linear layer"""

    def __init__(self, base_model, out_dim):
        super(DoubleResNetSimCLRDownstream, self).__init__()

        self.resnet_dict = {"resnet18": models.resnet18,
                            "resnet50": models.resnet50,}
        

        self.backbone2 = self.resnet_dict.get(base_model)(pretrained=False, num_classes=out_dim)
        dim_mlp2 = self.backbone2.fc.in_features
        
        # If you are using multimodal data you can un-comment the following lines:
        # self.backbone1 = self.resnet_dict.get(base_model)(pretrained=False, num_classes=out_dim)
        # dim_mlp1 = self.backbone1.fc.in_features
        
        # add final linear layer
        self.fc = torch.nn.Linear(dim_mlp2, out_dim, bias=True)
        # self.fc = torch.nn.Linear(dim_mlp1 + dim_mlp2, out_dim, bias=True)

        # self.backbone1.fc = torch.nn.Identity()
        self.backbone2.fc = torch.nn.Identity()

    def _get_basemodel(self, model_name):
        try:
            model = self.resnet_dict[model_name]
        except KeyError:
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
        else:
            return model

    def forward(self, x):
        x2 = self.backbone2(x["s2"])

        # If you are using multimodal data you can un-comment the following lines and comment z = self.fc(x2):
        # x1 = self.backbone1(x["s1"])
        # z = torch.cat([x1, x2], dim=1)
        # z = self.fc(z)
     
        z = self.fc(x2)
        
        return z
    
    def load_trained_state_dict(self, weights):
        """load the pre-trained backbone weights"""
        
        # remove the MLP projection heads
        for k in list(weights.keys()):
            if k.startswith(('backbone1.fc', 'backbone2.fc')):
                del weights[k]
        
        log = self.load_state_dict(weights, strict=False)
        assert log.missing_keys == ['fc.weight', 'fc.bias']
        
        # freeze all layers but the last fc
        for name, param in self.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                param.requires_grad = False

data_config = {
    'train_dir': '../data/data_disini',
    'val_dir': '../data/data_disini',
    'train_mode': 'validation', # 'test', 'validation'
    'val_mode': 'test', # 'test', 'validation'
    'num_classes': 9, # kepake
    'clip_sample_values': True, # clip (limit) values
    'train_used_data_fraction': 1, # fraction of data to use, should be in the range [0, 1]
    'val_used_data_fraction': 1,
    'image_px_size': 224,
    'cover_all_parts_train': True, # if True, if image_px_size is not 224 during training, we use a random crop of the image
    'cover_all_parts_validation': True, # if True, if image_px_size is not 224 during validation, we use a non-overlapping sliding window to cover the entire image
    'seed': 42,
}

train_config = {
    's1_input_channels': 2,
    's2_input_channels': 13,
    'finetuning': True, # If false, backbone layers is frozen and only the head is trained
    'classifier_lr': 3e-6,
    'learning_rate': 0.00001,
    'adam_betas': (0.9, 0.999), 
    'weight_decay': 0.001,
    'dataloader_workers': 4, # dipake
    'batch_size': 16, # dipake
    'epochs': 5, # diapke
    'target': 'dfc_label' # dipake
}

train_dataset = DFCDataset(
    data_config['train_dir'],
    mode=data_config['train_mode'],
    clip_sample_values=data_config['clip_sample_values'],
    used_data_fraction=data_config['train_used_data_fraction'],
    image_px_size=data_config['image_px_size'],
    cover_all_parts=data_config['cover_all_parts_train'],
    seed=data_config['seed'],
    
    add_cacao=True,
)
val_dataset = DFCDataset(
    data_config['val_dir'],
    mode=data_config['val_mode'],
    clip_sample_values=data_config['clip_sample_values'],
    used_data_fraction=data_config['val_used_data_fraction'],
    image_px_size=data_config['image_px_size'],
    cover_all_parts=data_config['cover_all_parts_validation'],
    seed=data_config['seed'],

    add_cacao=True
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_config['batch_size'],
    shuffle=True,
    pin_memory=True,
    num_workers=train_config['dataloader_workers'],
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=train_config['batch_size'],
    shuffle=False,
    num_workers=train_config['dataloader_workers'],
)

  warn(f"Failed to load image Python extension: {e}")


In [4]:
# sanity check

import pandas as pd

val_obs = pd.read_csv("../data/data_disini/validation_observations.csv", header=None, names=["Season", "Scene", "ID"])
test_obs = pd.read_csv("../data/data_disini/test_observations.csv", header=None, names=["Season", "Scene", "ID"])

val_rubber = pd.read_csv("../data/data_disini/validation_observations_rubber.csv", header=None, names=["Season", "Scene", "ID"])
test_rubber = pd.read_csv("../data/data_disini/test_observations_rubber.csv", header=None, names=["Season", "Scene", "ID"])

val_dfc = [x for x in os.listdir("../data/data_disini/ROIs0000_validation/dfc_0") if x.split(".")[-1] == "tif"]
test_dfc = [x for x in os.listdir("../data/data_disini/ROIs0000_test/dfc_0") if x.split(".")[-1] == "tif"]

val_s2 = [x for x in os.listdir("../data/data_disini/ROIs0000_validation/s2_0") if x.split(".")[-1] == "tif"]
test_s2 = [x for x in os.listdir("../data/data_disini/ROIs0000_test/s2_0") if x.split(".")[-1] == "tif"]

p = f"""
## Validation ##
Loader: {len(train_dataset)}
Original csv: {len(val_obs["ID"].unique())}
New csv: {len(val_rubber["ID"].unique())}
DFC: {len(val_dfc)}
S2: {len(val_s2)}

## Test ##
Loader: {len(val_dataset)}
Original csv: {test_obs.shape[0]}
New csv: {test_rubber.shape[0]}
DFC: {len(test_dfc)}
S2: {len(test_s2)}
"""

print(p)


## Validation ##
Loader: 1053
Original csv: 986
New csv: 1053
DFC: 1053
S2: 1053

## Test ##
Loader: 5397
Original csv: 5128
New csv: 5397
DFC: 5397
S2: 5397



In [5]:
base_model = "resnet18"
num_classes = 9
model = eval('DoubleResNetSimCLRDownstream')(base_model, num_classes)

model.backbone2.conv1 = torch.nn.Conv2d(
    train_config['s2_input_channels'],
    64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    bias=False,
)

device = torch.device("mps")
checkpoint = torch.load("../checkpoints/resnet18.pth", map_location=torch.device('mps'))
model.load_trained_state_dict(checkpoint["state_dict"])
model = model.to(device)

### Training ### 
if train_config['finetuning']:
    # train all parameters (backbone + classifier head)
    param_backbone = []
    param_head = []
    for p in model.parameters():
        if p.requires_grad:
            param_head.append(p)
        else:
            param_backbone.append(p)
        p.requires_grad = True
    # parameters = model.parameters()
    parameters = [
        {"params": param_backbone},  # train with default lr
        {
            "params": param_head,
            "lr": train_config['classifier_lr'],
        },  # train with classifier lr
    ]
    print("Finetuning")
else:
    # train only final linear layer for SSL methods
    print("Frozen backbone")
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))

optimizer = torch.optim.Adam(
    parameters,
    lr=train_config['learning_rate'],
    betas=train_config['adam_betas'],
    weight_decay=train_config['weight_decay'],
)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255, reduction="mean").to(device)

Finetuning


In [6]:
### Training ### 
if train_config['finetuning']:
    # train all parameters (backbone + classifier head)
    param_backbone = []
    param_head = []
    for p in model.parameters():
        if p.requires_grad:
            param_head.append(p)
        else:
            param_backbone.append(p)
        p.requires_grad = True
    # parameters = model.parameters()
    parameters = [
        {"params": param_backbone},  # train with default lr
        {
            "params": param_head,
            "lr": train_config['classifier_lr'],
        },  # train with classifier lr
    ]
    print("Finetuning")
else:
    # train only final linear layer for SSL methods
    print("Frozen backbone")
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))

optimizer = torch.optim.Adam(
    parameters,
    lr=train_config['learning_rate'],
    betas=train_config['adam_betas'],
    weight_decay=train_config['weight_decay'],
)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255, reduction="mean").to(device)

step = 0
for epoch in range(train_config['epochs']):
    # Model Training
    model.train()
    step += 1

    pbar = tqdm(train_loader)

    # track performance
    epoch_losses = torch.Tensor()
    metrics = ClasswiseAccuracy(data_config['num_classes'])

    for idx, sample in enumerate(pbar):

        if "x" in sample.keys():
            if torch.isnan(sample["x"]).any():
                # some s1 scenes are known to have NaNs...
                continue
        else:
            if torch.isnan(sample["s2"]).any():
                # some s1 scenes are known to have NaNs...
                continue

        # load input
        s2 = sample["s2"].to(device)
        img = {"s2": s2}
        
        # if you are using a unimodal dataset (s1 for example), you may un-comment the following lines:
        # s1 = sample["s1"].to(device)
        # img = {"s1": s1, "s2": s2}
        
        # load target
        y = sample[train_config['target']].long().to(device)
        
        # model output
        y_hat = model(img)
        
        # loss computation
        loss = criterion(y_hat, y)
        
        # backward step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # get prediction 
        _, pred = torch.max(y_hat, dim=1)

        epoch_losses = torch.cat([epoch_losses, loss[None].detach().cpu()])
        metrics.add_batch(y, pred)

        pbar.set_description(f"Epoch:{epoch}, Training Loss:{epoch_losses[-100:].mean():.4}")

    mean_loss = epoch_losses.mean()

    train_stats = {
            "train_loss": mean_loss.item(),
            "train_average_accuracy": metrics.get_average_accuracy(),
            "train_overall_accuracy": metrics.get_overall_accuracy(),
            **{
                "train_accuracy_" + k: v
                for k, v in metrics.get_classwise_accuracy().items()
            },
        }
    print(train_stats)

    if epoch % 2 == 0:  

        # Model Validation
        model.eval()
        pbar = tqdm(val_loader)

        # track performance
        epoch_losses = torch.Tensor()
        metrics = ClasswiseAccuracy(data_config['num_classes'])

        with torch.no_grad():
            for idx, sample in enumerate(pbar):
                if "x" in sample.keys():
                    if torch.isnan(sample["x"]).any():
                        # some s1 scenes are known to have NaNs...
                        continue
                else:
                    if torch.isnan(sample["s2"]).any():
                        # some s1 scenes are known to have NaNs...
                        continue
                # load input
                s2 = sample["s2"].float().to(device)
                img = {"s2": s2}

                # if you are using a unimodal dataset (s1 for example), you may un-comment the following lines:
                # s1 = sample["s1"].to(device)
                # img = {"s1": s1, "s2": s2}

                # load target
                y = sample[train_config['target']].long().to(device)

                # model output
                y_hat = model(img)

                # loss computation
                loss = criterion(y_hat, y)

                # get prediction 
                _, pred = torch.max(y_hat, dim=1)

                epoch_losses = torch.cat([epoch_losses, loss[None].detach().cpu()])
                metrics.add_batch(y, pred)


                pbar.set_description(f"Validation Loss:{epoch_losses[-100:].mean():.4}")

            mean_loss = epoch_losses.mean()

            val_stats = {
                "validation_loss": mean_loss.item(),
                "validation_average_accuracy": metrics.get_average_accuracy(),
                "validation_overall_accuracy": metrics.get_overall_accuracy(),
                **{
                    "validation_accuracy_" + k: v
                    for k, v in metrics.get_classwise_accuracy().items()
                },
            }

            print(f"Epoch:{epoch}", val_stats)
            
            # Save model checkpoint every 2 epochs 
            if epoch % 2 == 0:
                if epoch == 0:
                    continue

                save_weights_path = (
                    "checkpoints/" + "-".join(["classifier", "epoch", str(epoch)]) + ".pth"
                )
                torch.save(model.state_dict(), save_weights_path)

Finetuning


  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
Epoch:0, Training Loss:1.987: 100%|██████████| 66/66 [00:44<00:00,  1.48it/s]


{'train_loss': 1.9866182804107666, 'train_average_accuracy': 0.2798751358910425, 'train_overall_accuracy': 0.3808167141500475, 'train_accuracy_class_3': 0.45698924731182794, 'train_accuracy_class_5': 0.06557377049180328, 'train_accuracy_class_4': 0.45098039215686275, 'train_accuracy_class_7': 0.48148148148148145, 'train_accuracy_class_1': 0.0, 'train_accuracy_class_2': 0.06862745098039216, 'train_accuracy_class_8': 0.9552238805970149, 'train_accuracy_class_0': 0.04, 'train_accuracy_class_6': 0.0}


  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
Validation Loss:1.831: 100%|██████████| 338/338 [00:51<00:00,  6.55it/s]


Epoch:0 {'validation_loss': 1.9164905548095703, 'validation_average_accuracy': 0.3576772741513186, 'validation_overall_accuracy': 0.3989253288864184, 'validation_accuracy_class_7': 0.44741235392320533, 'validation_accuracy_class_0': 0.15963060686015831, 'validation_accuracy_class_4': 0.796849087893864, 'validation_accuracy_class_5': 0.25981308411214954, 'validation_accuracy_class_1': 0.0, 'validation_accuracy_class_2': 0.0, 'validation_accuracy_class_6': 0.0, 'validation_accuracy_class_3': 0.6, 'validation_accuracy_class_8': 0.9553903345724907}


  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
Epoch:1, Training Loss:1.482: 100%|██████████| 66/66 [00:33<00:00,  1.98it/s]


{'train_loss': 1.4821525812149048, 'train_average_accuracy': 0.47181397293794775, 'train_overall_accuracy': 0.6638176638176638, 'train_accuracy_class_7': 0.8352272727272727, 'train_accuracy_class_8': 0.9552238805970149, 'train_accuracy_class_4': 0.7814569536423841, 'train_accuracy_class_3': 0.8586956521739131, 'train_accuracy_class_5': 0.2833333333333333, 'train_accuracy_class_1': 0.0, 'train_accuracy_class_2': 0.2692307692307692, 'train_accuracy_class_6': 0.0, 'train_accuracy_class_0': 0.2631578947368421}


  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")
Epoch:2, Training Loss:1.099: 100%|██████████| 66/66 [00:16<00:00,  4.05it/s]


KeyboardInterrupt: 