<a href="https://colab.research.google.com/github/Hydrometeorological-Remote-Sensing/cropmapping/blob/sw_ww_test/3_train_pytorch_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive; the data folder and checkpoint path used
# later in this notebook both live under this mount point.
drive.mount("/content/drive")


In [None]:
!cp / content/drive/MyDrive/RDA/model_and_env/data_loader.py .
!cp / content/drive/MyDrive/RDA/model_and_env/unet_pytorch.py .


In [None]:
!pip install segmentation_models_pytorch - q
!pip install affine - q


In [None]:
import datetime
import segmentation_models_pytorch as smp
import numpy as np
import torch
from glob import glob
import torch.utils.data
from data_loader import data_load
import tensorflow as tf
from unet_pytorch import build_unet
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy
import pandas as pd
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader, Dataset

import torchvision
from torchvision import transforms
import numpy as np
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score


In [None]:
# TensorBoard writer shared by train_model(), which logs the per-epoch
# "Loss/train" and "Loss/val" scalars to it.
writer = SummaryWriter()


def dataset(bands, description, batch, folder, targets, train=True, buffer_size=6000):
    """Build a dataset from the gzipped record files of one data split.

    Args:
        bands: Input band names, forwarded to ``data_load``.
        description: Feature description, forwarded to ``data_load``.
        batch: Batch size, forwarded as ``batch_size``.
        folder: Root folder containing the ``train`` and ``valid`` sub-folders.
        targets: Response names, forwarded as ``response``.
        train: Read ``<folder>/train/*.gz`` when truthy, otherwise
            ``<folder>/valid/*.gz``.
        buffer_size: Forwarded as ``buffer_size`` (presumably the shuffle
            buffer -- confirm in data_loader.py).

    Returns:
        The object returned by ``data_load(...).get_training_dataset()``.
    """
    split_dir = "train" if train else "valid"
    pattern = folder + "/{}/*.gz".format(split_dir)
    tf_files = glob(pattern)

    # Show which files were matched so an empty or wrong folder is easy to spot.
    print(tf_files)

    loader = data_load(
        tf_files,
        bands,
        description,
        batch_size=batch,
        response=targets,
        buffer_size=buffer_size,
    )
    return loader.get_training_dataset()


class StepRunner:
    """Runs the forward pass (and, when training, the optimizer update) for one batch.

    Args:
        net: The model to run.
        loss_fn: Callable ``loss_fn(preds, labels)`` returning a scalar tensor.
        stage: ``"train"`` selects the training path; any other value selects
            the evaluation path. Also used as the prefix of metric names.
        metrics_dict: Optional mapping ``name -> metric_fn(preds, labels)``.
            ``None`` is treated as "no metrics".
        optimizer: Optimizer to step during training; may be ``None``.
    """

    def __init__(self, net, loss_fn, stage="train", metrics_dict=None, optimizer=None):
        self.net = net
        self.loss_fn = loss_fn
        # The default is None; normalise it to an empty mapping so that step()
        # (and anything else iterating the metrics) works without metrics.
        self.metrics_dict = metrics_dict if metrics_dict is not None else {}
        self.stage = stage
        self.optimizer = optimizer

    def step(self, features, labels):
        """Compute loss and metrics for one batch; update weights when training.

        Returns:
            Tuple ``(loss_value, step_metrics)`` where ``loss_value`` is a Python
            float and ``step_metrics`` maps ``"<stage>_<name>"`` to a float.
        """
        # Forward pass and loss.
        preds = self.net(features)
        loss = self.loss_fn(preds, labels)

        # Backward pass and parameter update (training with an optimizer only).
        if self.optimizer is not None and self.stage == "train":
            loss.backward()
            self.optimizer.step()
            # Clear gradients now so the next backward() starts from zero.
            self.optimizer.zero_grad()

        # Per-step metrics, keyed by stage.
        step_metrics = {
            self.stage + "_" + name: metric_fn(preds, labels).item()
            for name, metric_fn in self.metrics_dict.items()
        }
        return loss.item(), step_metrics

    def train_step(self, features, labels):
        """Run ``step`` with the network in training mode."""
        self.net.train()  # training mode: dropout layers are active
        return self.step(features, labels)

    @torch.no_grad()
    def eval_step(self, features, labels):
        """Run ``step`` with the network in eval mode and gradients disabled."""
        self.net.eval()  # inference mode: dropout layers are disabled
        return self.step(features, labels)

    def __call__(self, features, labels):
        """Dispatch to ``train_step`` or ``eval_step`` according to ``stage``."""
        if self.stage == "train":
            return self.train_step(features, labels)
        else:
            return self.eval_step(features, labels)


class EpochRunner:
    """Drives a step runner over every batch of a dataset for one epoch.

    Args:
        steprunner: Callable ``steprunner(features, labels)`` returning
            ``(loss, step_metrics)``. It must expose ``stage``, ``net`` and
            ``metrics_dict`` attributes.
    """

    def __init__(self, steprunner):
        self.steprunner = steprunner
        self.stage = steprunner.stage

    def _target_device(self):
        """Return the device of the network's first parameter (CPU if it has none)."""
        first_param = next(iter(self.steprunner.net.parameters()), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def __call__(self, dataset, test_dataset=None):
        """Run one epoch over ``dataset`` and return the epoch log.

        Args:
            dataset: Object whose ``as_numpy_iterator()`` yields batches where
                ``batch[0]`` holds the features and ``batch[1]`` the labels.
            test_dataset: Accepted for compatibility with callers that pass a
                second (unlabeled) stream; it is currently not consumed. The
                previous draft built tensors from it that were never used.

        Returns:
            Dict with ``"<stage>_loss"`` (mean step loss over the epoch, 0.0 if
            the dataset yielded no batches) plus one ``"<stage>_<name>"`` entry
            per metric, taken from ``metric_fn.compute()``.
        """
        # Put the batch on the same device as the model instead of relying on
        # a notebook-level global being defined before this runs.
        device = self._target_device()

        total_loss, n_steps = 0.0, 0
        for batch in dataset.as_numpy_iterator():
            features = torch.tensor(batch[0]).to(device)
            labels = torch.tensor(batch[1]).to(torch.int64).to(device)

            loss, step_metrics = self.steprunner(features, labels)
            step_log = dict({self.stage + "_loss": loss}, **step_metrics)
            total_loss += loss
            n_steps += 1
            print(step_log, flush=True)

        # Average over the number of steps actually run (not steps + 1).
        epoch_loss = total_loss / n_steps if n_steps else 0.0

        metrics = self.steprunner.metrics_dict or {}
        epoch_metrics = {
            self.stage + "_" + name: metric_fn.compute().item()
            for name, metric_fn in metrics.items()
        }
        epoch_log = dict({self.stage + "_loss": epoch_loss}, **epoch_metrics)
        print(epoch_log)

        # Reset metric accumulators so the next epoch starts fresh.
        for metric_fn in metrics.values():
            metric_fn.reset()
        return epoch_log


def printlog(info):
    """Print a timestamped banner line followed by ``info``.

    The banner is a blank line, then the current local time
    (``YYYY-MM-DD HH:MM:SS``) framed by 40 ``=`` characters on each side.
    ``info`` is converted with ``str`` and followed by an empty line.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    rule = "==========" * 4
    banner = "\n" + rule + "%s" % stamp + rule
    print(banner)
    print(str(info) + "\n")


def train_model(
    net,
    optimizer,
    loss_fn,
    metrics_dict,
    train_data,
    val_data=None,
    epochs=30,
    ckpt_path=None,
    patience=5,
    monitor="val_loss",
    mode="min",
):
    """Train ``net`` with optional validation, checkpointing and early stopping.

    Args:
        net: Model to train.
        optimizer: Optimizer stepping ``net``'s parameters.
        loss_fn: Loss callable passed to ``StepRunner``.
        metrics_dict: Mapping of metric name to metric object; a deep copy is
            handed to each stage's ``StepRunner``.
        train_data: Training dataset consumed by ``EpochRunner``.
        val_data: Optional validation dataset. When ``None`` no validation
            pass is run and no ``val_*`` entries are recorded.
        epochs: Maximum number of epochs.
        ckpt_path: Where to save the best ``state_dict``. When ``None`` no
            checkpoint is written or restored.
        patience: Stop once ``monitor`` has not improved for more than this
            many epochs.
        monitor: History key used to pick the best epoch (e.g. ``"val_loss"``).
        mode: ``"max"`` if larger ``monitor`` is better, otherwise smaller is
            better.

    Returns:
        ``pandas.DataFrame`` with one row per completed epoch.
    """
    history = {}
    has_checkpoint = False

    for epoch in range(1, epochs + 1):
        printlog("Epoch {0} / {1}".format(epoch, epochs))

        # 1. train -------------------------------------------------
        train_step_runner = StepRunner(
            net=net,
            stage="train",
            loss_fn=loss_fn,
            metrics_dict=deepcopy(metrics_dict),
            optimizer=optimizer,
        )
        train_epoch_runner = EpochRunner(train_step_runner)
        train_metrics = train_epoch_runner(train_data)

        for name, metric in train_metrics.items():
            history.setdefault(name, []).append(metric)

        writer.add_scalar("Loss/train", history["train_loss"][-1], epoch)

        # 2. validate -------------------------------------------------
        if val_data is not None:
            val_step_runner = StepRunner(
                net=net,
                stage="val",
                loss_fn=loss_fn,
                metrics_dict=deepcopy(metrics_dict),
            )
            val_epoch_runner = EpochRunner(val_step_runner)
            with torch.no_grad():
                val_metrics = val_epoch_runner(val_data)
            for name, metric in val_metrics.items():
                history.setdefault(name, []).append(metric)
            # Only log the validation loss when a validation pass actually ran.
            writer.add_scalar("Loss/val", history["val_loss"][-1], epoch)

        history.setdefault("epoch", []).append(epoch)

        # 3. checkpoint + early stopping ------------------------------
        if monitor not in history:
            # E.g. monitor="val_loss" without validation data: there is no
            # score to rank epochs by, so train for the full epoch budget.
            if epoch == 1:
                print(
                    "<<<<<< {} is not recorded; checkpointing and early "
                    "stopping are disabled >>>>>>".format(monitor)
                )
            continue

        arr_scores = history[monitor]
        best_score_idx = int(
            np.argmax(arr_scores) if mode == "max" else np.argmin(arr_scores)
        )
        if best_score_idx == len(arr_scores) - 1:
            if ckpt_path is not None:
                torch.save(net.state_dict(), ckpt_path)
                has_checkpoint = True
            print(
                "<<<<<< reach best {0} : {1} >>>>>>".format(
                    monitor, arr_scores[best_score_idx]
                )
            )
        if len(arr_scores) - best_score_idx > patience:
            print(
                "<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(
                    monitor, patience
                )
            )
            break

    # Restore the best weights once, after training finishes (including after
    # an early stop). Reloading inside the loop would throw away the progress
    # of every non-improving epoch and be skipped entirely on `break`.
    if has_checkpoint:
        net.load_state_dict(torch.load(ckpt_path))

    return pd.DataFrame(history)


In [None]:
if __name__ == "__main__":
    # Input bands read from each record; together with the target they form
    # the feature description used to parse the files.
    bands = [
        "blue",
        "green",
        "red",
        "nir",
        "swir1",
        "swir2",
        "ndvi",
        "nirv",
    ]
    targets = ["cropland"]
    features = bands + targets
    ckpt_path = (
        "/content/drive/MyDrive/RDA/model_and_env/sw_ww_test/ww_56_model/checkpoint.pt"
    )
    nclass = 2
    batch = 32
    class_names = ["Others", "Wheat"]
    # Every band and the target is parsed as a fixed 256x256 float32 feature.
    columns = [
        tf.io.FixedLenFeature(shape=[256, 256], dtype=tf.float32) for _ in features
    ]
    description = dict(zip(features, columns))
    metrics_dict = {}
    folder = "/content/drive/MyDrive/RDA/training cdl/data"

    # `targets` is a required positional parameter of dataset(); omitting it
    # raises TypeError before any data is read.
    training = dataset(bands, description, batch, folder, targets, buffer_size=1000)
    valid = dataset(
        bands, description, batch, folder, targets, train=False, buffer_size=1000
    )

    # Use the GPU when one is attached, otherwise fall back to CPU instead of
    # crashing in .cuda().
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss_fn = smp.losses.DiceLoss(mode="multiclass", classes=[0, 1])
    loss_name = "DiceLoss"
    model = build_unet(len(bands), nclass).to(device)

    # List the convolution layers as a quick architecture sanity check.
    for name, layer in model.named_modules():
        if isinstance(layer, torch.nn.Conv2d):
            print(name, layer)
    # NOTE(review): assumes the model returned by build_unet exposes an
    # `outputs` attribute -- confirm in unet_pytorch.py.
    print(model.outputs)
    model = nn.DataParallel(model)
    optimizer = torch.optim.Adam(
        [
            dict(params=model.parameters(), lr=0.0001),
        ]
    )

    # Keep the per-epoch history so it can be inspected or plotted afterwards.
    dfhistory = train_model(
        model,
        optimizer,
        loss_fn,
        metrics_dict,
        training,
        val_data=valid,
        epochs=10,
        ckpt_path=ckpt_path,
    )
    # Make sure the last TensorBoard events reach disk.
    writer.flush()
