In [None]:
import torch
import torch.nn as nn
import numpy as np
from data_loader import data_load
from att2d_unet import build_unet
from ramps import get_current_consistency_weight, update_ema_variables
from glob import glob
import tensorflow as tf
from time import time
from datetime import datetime
from monai.data import decollate_batch
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import (
    Compose,
    AsDiscrete,
    EnsureType,
)
import os
from torch.utils.tensorboard import SummaryWriter
import rioxarray
import numpy as np
from sklearn.metrics import average_precision_score
from metrics_evaluator import MetricsEvaluator

# TensorBoard writer used by the training cell (scalars are added per epoch and
# the writer is closed once training finishes). No log_dir is passed, so
# SummaryWriter falls back to its default log directory.
writer = SummaryWriter()


In [None]:
def load_dataset(folder, label=True, shuffle_size=500):
    """Build a dataset from the ``*.gz`` record files found in ``folder``.

    Args:
        folder: Directory that is searched (non-recursively) for ``*.gz`` files.
        label: When True, every name in ``BANDS + TARGETS`` is parsed with
            shape ``KERNEL_SHAPE`` and the loader's training dataset is
            returned. When False, only ``BANDS`` are parsed, with shape
            ``BUFFERED_SHAPE``, and the loader's prediction dataset is returned.
        shuffle_size: Forwarded unchanged to ``data_load``.

    Returns:
        Whatever ``data_load(...).get_training_dataset()`` (``label=True``) or
        ``data_load(...).get_pridiction_dataset()`` (``label=False``) returns.

    Raises:
        FileNotFoundError: If ``folder`` contains no ``*.gz`` files.
    """
    feature_names = BANDS + TARGETS if label else BANDS
    patch_shape = KERNEL_SHAPE if label else BUFFERED_SHAPE

    # glob() yields paths in arbitrary, filesystem-dependent order; sorting makes
    # the file list (and any positional take/skip split built on top of the
    # dataset) reproducible between runs.
    tf_files = sorted(glob(f"{folder}/*.gz"))
    if not tf_files:
        raise FileNotFoundError(f"No '*.gz' record files found in {folder!r}")

    # One fixed-length float32 feature spec per parsed column.
    description = {
        name: tf.io.FixedLenFeature(shape=patch_shape, dtype=tf.float32)
        for name in feature_names
    }
    loader = data_load(
        tf_files,
        BANDS,
        description,
        response=TARGETS,
        shuffle_size=shuffle_size,
    )
    if label:
        return loader.get_training_dataset()
    # "pridiction" (sic) is the method name exposed by the data_loader module.
    return loader.get_pridiction_dataset()


# --- Features / targets ------------------------------------------------------
# Names of the input feature columns and of the label column(s).
BANDS = ["blue", "green", "red", "nir", "swir1", "swir2", "ndvi", "nirv"]
TARGETS = ["cropland"]
NCLASS = 2

# --- Patch geometry ----------------------------------------------------------
# KERNEL_SHAPE is the patch size used for labeled records; BUFFERED_SHAPE is the
# patch size used for unlabeled records (kernel plus the full buffer per axis).
KERNEL_SHAPE = [256, 256]
KERNEL_BUFFER = [128, 128]
BUFFERED_SHAPE = [size + pad for size, pad in zip(KERNEL_SHAPE, KERNEL_BUFFER)]
# Half of the buffer on each axis, and that offset plus the kernel size: together
# they bound a KERNEL_SHAPE-sized window inside a buffered patch.
X_BUFFER = KERNEL_BUFFER[0] // 2
Y_BUFFER = KERNEL_BUFFER[1] // 2
X_BUFFERED = X_BUFFER + KERNEL_SHAPE[0]
Y_BUFFERED = Y_BUFFER + KERNEL_SHAPE[1]

# --- Data pipeline -----------------------------------------------------------
shuffle_size = 500
batch = 128

# --- Paths -------------------------------------------------------------------
model_folder = "/bess23/huaize/semi-supervised/models/"
pred_folder = "/bess23/huaize/semi-supervised/data/unlabeled/Colorado/earthengine_Colorado"

# --- Compute device ----------------------------------------------------------
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# batch_test = batch * 5

In [None]:
def power_jaccard_loss(y_true, y_pred, power=2):
    '''
    Power Jaccard loss.

    For every sample the loss term is

        (sum(y_true * y_pred) + 1)
        / (sum(y_true**power + y_pred**power) - sum(y_true * y_pred) + 1)

    and the function returns one minus the batch mean of that ratio. The
    formulation follows Duque-Arias et al., "On Power Jaccard Losses for
    Semantic Segmentation" (VISAPP 2021). With ``power=1`` this reduces to the
    plain smoothed soft-Jaccard loss.

    Previously ``power`` was accepted but never used, so the result always
    corresponded to ``power=1`` regardless of the argument.

    Input:
        y_true: ground truth; cast to float32. The sums run over axes 1, 2 and
            3, so a rank-4 tensor is required (batch on axis 0).
        y_pred: prediction with the same shape as ``y_true``; cast to float32.
        power: exponent applied element-wise to ``y_true`` and ``y_pred`` in
            the denominator.
    Output:
        loss: scalar tensor, 1 - mean over the batch of the power Jaccard index
    '''
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    reduce_axes = [1, 2, 3]
    intersection = tf.reduce_sum(y_true * y_pred, axis=reduce_axes)
    # The exponent only affects the "sum of both masks" term of the denominator.
    power_sum = tf.reduce_sum(
        tf.pow(y_true, power) + tf.pow(y_pred, power), axis=reduce_axes
    )
    # The +1 in numerator and denominator is the smoothing constant that keeps
    # the ratio defined for empty masks.
    return 1 - tf.reduce_mean((intersection + 1) / (power_sum - intersection + 1))

In [None]:
class dice_loss(nn.Module):
    """Soft Dice loss computed per class over the whole batch.

    For each channel ``c`` the Dice score is

        (2 * sum(prediction_c * target_c) + eps)
        / (sum(prediction_c) + sum(target_c) + eps)

    with the sums taken over the batch and both spatial axes (dims 0, 2, 3).
    The loss is ``1 - mean(score * class_weights)``.

    Args:
        class_weights: Optional per-class multipliers (tensor or sequence with
            one entry per channel). ``None`` means a weight of 1 per class.
            Note the weighted scores are averaged with a plain mean, i.e. they
            are not normalised by the sum of the weights.
    """

    def __init__(self, class_weights=None):
        super().__init__()
        self.class_weights = class_weights

    def forward(self, prediction, target, to_onehot_y=False):
        """Return the scalar Dice loss between ``prediction`` and ``target``.

        Args:
            prediction: Tensor indexed as (batch, class, dim2, dim3); used as
                is (no activation is applied here).
            target: With ``to_onehot_y=False`` a tensor of the same shape as
                ``prediction``. With ``to_onehot_y=True`` a tensor of class
                indices shaped (batch, 1, dim2, dim3) or (batch, dim2, dim3).
            to_onehot_y: Convert an index-valued ``target`` to one-hot along
                dim 1 before computing the loss.

        Returns:
            0-dim tensor holding the loss.
        """
        epsilon = 1e-8
        if to_onehot_y:
            target = target.long()
            # scatter_ needs the index tensor to have the same rank as the
            # output, so add the channel axis when the labels come without it.
            if target.dim() == prediction.dim() - 1:
                target = target.unsqueeze(1)
            target = torch.zeros_like(prediction).scatter_(1, target, 1)

        # Build the weights per call instead of caching them on the module:
        # caching the default would freeze the class count and device of the
        # first batch and break later calls with different inputs.
        if self.class_weights is None:
            class_weights = torch.ones(
                target.shape[1], dtype=prediction.dtype, device=prediction.device
            )
        else:
            class_weights = torch.as_tensor(
                self.class_weights, dtype=prediction.dtype, device=prediction.device
            )

        reduce_dims = (0, 2, 3)
        intersection = 2 * torch.sum(prediction * target, dim=reduce_dims) + epsilon
        union = (
            torch.sum(prediction, dim=reduce_dims)
            + torch.sum(target, dim=reduce_dims)
            + epsilon
        )
        class_dice_scores = intersection / union
        weighted_dice_scores = class_dice_scores * class_weights
        return 1 - torch.mean(weighted_dice_scores)

In [None]:
# Labeled dataset that is split below into validation and training parts.
dataset = load_dataset(
    "/bess23/huaize/semi-supervised/data/labeled/train",
    label=True,
    shuffle_size=shuffle_size,
)

# Number of labeled samples; hard-coded to avoid a full pass over the dataset.
dataset_size = 27727  # dataset.reduce(0, lambda x, _: x + 1).numpy()
validation_size = int(0.3 * dataset_size)  # 0.3

# The first `validation_size` samples are held out for validation and the
# remainder is used for training. (The two assignments used to be swapped, so
# training only saw 30% of the data and validation ran on the other 70%.)
# NOTE(review): take/skip only give a stable, disjoint split if the dataset
# yields samples in the same order on every iteration - confirm that
# data_load does not reshuffle between epochs.
validation_dataset = dataset.take(validation_size).batch(batch)
train_dataset = dataset.skip(validation_size).batch(batch)

# Unlabeled (buffered) patches; the training cell iterates over this dataset
# alongside the labeled batches.
test_dataset = load_dataset(pred_folder, label=False).shuffle(shuffle_size).batch(batch)


In [None]:
# %% train
run_time = datetime.today().strftime("%m_%d_%H_%M_%S")
!git checkout -b experiment_{run_time}
max_epochs = 200
MeanTeacherEpoch = 80
val_interval = 1
best_metric = -1
best_metric_epoch = -1
iter_num = 0
epoch_loss_values = []
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=NCLASS)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=NCLASS)])
ckpt_path = os.path.join(model_folder, f"best_{run_time}.pth")
class_weights =  None #torch.tensor([1,1]).to(device)
criterion = dice_loss(class_weights=class_weights)
model = build_unet(len(BANDS), NCLASS).cuda()
ema_model = build_unet(len(BANDS), NCLASS).cuda()
model = nn.DataParallel(model)
ema_model = nn.DataParallel(ema_model)
model.to(device)
ema_model.to(device)
att_type = "channel" # 'channel', 'spatial', 'channel_spatial
lr = 0.001
opt = torch.optim.Adam(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=10, verbose=True)
dice_metric = DiceMetric(include_background=False, reduction="mean")
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    start_time = time()
    model.train()
    epoch_loss = 0
    step = 0
    train_loader = train_dataset.as_numpy_iterator()
    val_loader = validation_dataset.as_numpy_iterator()
    unlabeled_train_loader = test_dataset.as_numpy_iterator()
    for labeled_batch, unlabeled_batch in zip(train_loader, unlabeled_train_loader):
        step += 1
        labeled_inputs, labels = (
            torch.tensor(labeled_batch[0]).to(device),
            torch.tensor(labeled_batch[1]).to(device),
        )
        unlabeled_batch = unlabeled_batch[
            slice(None), slice(None), X_BUFFER:X_BUFFERED, Y_BUFFER:Y_BUFFERED
        ]
        unlabeled_inputs = torch.tensor(unlabeled_batch).to(device)
        opt.zero_grad()
        noise_labeled = torch.clamp(torch.randn_like(labeled_inputs) * 0.1, -0.2, 0.2)
        noise_unlabeled = torch.clamp(
            torch.randn_like(unlabeled_inputs) * 0.1, -0.2, 0.2
        )
        noise_labeled_inputs = labeled_inputs + noise_labeled
        noise_unlabeled_inputs = unlabeled_inputs + noise_unlabeled

        outputs = model(labeled_inputs)
        with torch.no_grad():
            outputs_unlabeled = model(unlabeled_inputs)
            outputs_aug = ema_model(noise_labeled_inputs)
            outputs_unlabeled_aug = ema_model(noise_unlabeled_inputs)

        supervised_loss = criterion(outputs, labels, to_onehot_y=True)
        if epoch < MeanTeacherEpoch:
            consistency_loss = 0.0
        else:
            consistency_loss = torch.mean(
                criterion(outputs, outputs_aug, class_weights)
            ) + torch.mean(
                criterion(outputs_unlabeled, outputs_unlabeled_aug, class_weights)
            )
        consistency_weight = get_current_consistency_weight(
            (epoch - MeanTeacherEpoch), (max_epochs - MeanTeacherEpoch)
        )
        print(
            "consistency_weight is : {}, consistency_loss is: {}".format(
                consistency_weight, consistency_loss
            )
        )
        iter_num += 1
        loss = supervised_loss + consistency_weight * consistency_loss
        loss.backward()
        opt.step()
        update_ema_variables(model, ema_model, 0.99, iter_num)
        epoch_loss += loss.item()
        print(
            # f"{step}/{len(unlabeled_train_ds) // unlabeled_train_loader.batch_size}, "
            f"train_loss: {loss.item():.4f}"
        )

    epoch_loss /= step
    writer.add_scalar("Loss/train", epoch_loss, epoch)
    writer.add_scalar("supervised_loss/train", supervised_loss, epoch)
    writer.add_scalar("consistency_loss/train", consistency_loss, epoch)
    epoch_loss_values.append(epoch_loss)

    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            val_metric = 0
            metric_values = []
            for val_data in val_loader:
                val_inputs, val_labels = (
                    torch.tensor(val_data[0]).to(device),
                    torch.tensor(val_data[1]).to(device),
                )
                val_outputs = model(val_inputs)
                val_metric = criterion(val_outputs, val_labels, to_onehot_y=True)
                metric_values.append(val_metric)

            # aggregate the final mean dice result
            val_loss = torch.mean(torch.tensor(metric_values)).numpy()
            metric = 1 - loss
            print(f"val dice: {metric}")
            writer.add_scalar("val dice", metric, epoch)
            # reset the status for next validation round
            # dice_metric.reset()

        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(
                model.module.state_dict(),
                ckpt_path,
            )
            print("saved new best metric model")
        print(
            f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
            f"\nbest mean dice: {best_metric:.4f} "
            f"at epoch: {best_metric_epoch}"
        )
    # scheduler.step(val_loss)
    print(f"epoch time = {time() - start_time}")
writer.close()
!git commit -m "Ran experiment {ckpt_path} and made changes to code"
!git checkout main

In [None]:
# ckpt_path = '/bess23/huaize/semi-supervised/models/best_03_24_02_40_15.pth'
tif_filename = "/bess23/huaize/semi-supervised/tif/{}.tif".format(
    ckpt_path.split("/")[-1]
)
print(ckpt_path)
!python /bess23/huaize/semi-supervised/meanteacherseg/toimage.py --ckpt_path "{ckpt_path}" --pred_folder {pred_folder} --tif_filename {tif_filename} --attention_type {att_type}


In [None]:
def _mask_nodata(raster):
    """Return ``raster`` with every cell equal to its ``rio.nodata`` value masked by ``where``."""
    return raster.where(raster != raster.rio.nodata)


# NOTE(review): the reference raster is for South Dakota while `pred_folder`
# points at Colorado data - confirm the two rasters cover the same area.
ground_truth_tiff = 'earthengine/cdl_Winter_Wheat_South_Dakota_2022.tif'
gt_data = _mask_nodata(rioxarray.open_rasterio(ground_truth_tiff))

# Mask the prediction the same way, then reproject it to match the reference
# raster so both can be compared cell by cell; NaN marks missing cells.
pred_data = _mask_nodata(rioxarray.open_rasterio(tif_filename)).rio.reproject_match(
    gt_data, nodata=np.nan
)


In [None]:
# Evaluate the prediction raster against the reference raster.
evaluator = MetricsEvaluator(gt_data, pred_data)

# f1_score() returns four values; each exposes its number through `.data` and
# is printed rounded to three decimals.
precision, recall, f1, overlay_f1_score = evaluator.f1_score()
for metric_label, metric_result in (
    ("Precision:", precision),
    ("Recall:", recall),
    ("F1 Score:", f1),
    ("Overlay F1 Score:", overlay_f1_score),
):
    print(metric_label, round(float(metric_result.data), 3))

# Compute Dice coefficient
# dice_coefficient = evaluator.dice_coefficient()
# print("Dice Coefficient:", round(float(dice_coefficient), 3))

# Compute average precision
# ap = evaluator.average_precision()
# print("Average Precision:", round(float(ap), 3))
