In [1]:
import gc

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score
from torch.utils.data import DataLoader

from Models.U_net import U_net
# from Models.wmm import WMM
from Models.WMM import WMM
from Utils.Dataloader import Satellite_image_dataset, get_data
from Utils.Helper import MultiTrainHelper, train_mission

%matplotlib inline
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Default compute device for the notebook: the first GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [3]:
def train_net(train_loader, classifier, criterion, optimizer, device=None):
    """Train ``classifier`` for one epoch and return its mean training loss.

    Args:
        train_loader: Iterable of ``(images, labels)`` batches (e.g. a ``DataLoader``).
        classifier: Module mapping ``images`` to logits; switched to train mode here.
        criterion: Loss callable invoked as ``criterion(logits, target)``.
        optimizer: Optimizer over ``classifier``'s parameters; stepped once per batch.
        device: Device each batch is moved to. Defaults to the notebook-level ``DEVICE``.

    Returns:
        float: The epoch's loss averaged over samples. Each batch's loss is weighted
        by its batch size, so a short final batch does not skew the average.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    if device is None:
        device = DEVICE
    classifier.train()
    # Accumulate detached values: storing the raw `loss` tensors would keep their
    # autograd history alive all epoch, and `.item()` per step would force a sync.
    loss_sum = torch.zeros((), device=device)
    n_samples = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        target = labels.squeeze()
        if labels.size(0) == 1:
            # squeeze() also drops a batch axis of size 1 (e.g. the last batch of a
            # DataLoader without drop_last); restore it so the target keeps its batch
            # axis. Assumes the classifier's logits keep theirs -- confirm for new models.
            target = target.unsqueeze(0)
        optimizer.zero_grad()
        logits = classifier(images)
        loss = criterion(logits, target)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        loss_sum += loss.detach() * batch_size
        n_samples += batch_size
    if n_samples == 0:
        raise ValueError("train_loader yielded no batches")
    return (loss_sum / n_samples).item()

def test_net(test_loader, classifier, criterion,):
    classifier.eval()
    losses = []
    pred_list = []
    label_list = []
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            logits = classifier(images)
            loss = criterion(logits, labels.squeeze())
            pred = torch.where(torch.sigmoid(logits)>=0.5, 1, 0).cpu().numpy()
            pred = np.reshape(pred, (-1,))
            labels = np.reshape(labels.cpu().numpy(), (-1,))
            label_list.append(labels)
            pred_list.append(pred)
            losses.append(loss.item())
        
        all_preds = np.concatenate(pred_list)
        all_labels = np.concatenate(label_list)
        acc = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds)
        coppa = cohen_kappa_score(all_labels, all_preds)
        test_loss = np.mean(losses)
        print("Test result:\n test_loss: {:.4f}, accuracy: {:.4f}, f1: {:.4f}, coppa: {:.4f}".format(test_loss, acc, f1, coppa))
        
    return test_loss, acc, f1, coppa
def _eval_epochs(test_frequency, num_epochs):
    """Return the 1-based epochs on which ``train`` evaluates: epoch 1 and every ``test_frequency``-th."""
    return [epoch for epoch in range(1, num_epochs + 1)
            if epoch % test_frequency == 0 or epoch == 1]


def _plot_train_val(train_x, train, val_x, val, title, ylabel):
    """Draw one train-vs-val line chart against epoch numbers and show it."""
    _, ax = plt.subplots()
    ax.plot(train_x, train, label="train")
    ax.plot(val_x, val, label="val")
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Epoch")
    ax.legend()
    plt.show()


def plot_losses(train, val, test_frequency, num_epochs):
    """Plot the per-epoch training loss and the validation loss from evaluation epochs.

    ``train`` holds one value per epoch; ``val`` holds one value per evaluation epoch
    (epoch 1 and every ``test_frequency``-th). Epochs are 1-based, matching the
    "Starting epoch number N" lines printed by ``train``.
    """
    _plot_train_val(list(range(1, len(train) + 1)), train,
                    _eval_epochs(test_frequency, num_epochs), val,
                    "Loss Plot", "Loss")


def plot_acc(train, val, test_frequency, num_epochs):
    """Plot train/val accuracy; both lists hold one value per evaluation epoch (1-based)."""
    epochs = _eval_epochs(test_frequency, num_epochs)
    _plot_train_val(epochs, train, epochs, val, "Accuracy Plot", "Accuracy")


def plot_f1(train, val, test_frequency, num_epochs):
    """Plot train/val F1; both lists hold one value per evaluation epoch (1-based)."""
    epochs = _eval_epochs(test_frequency, num_epochs)
    _plot_train_val(epochs, train, epochs, val, "F1 Plot", "F1")
def train(classifier, num_epochs, train_loader, val_loader, criterion, optimizer, scheduler, test_frequency=5):
    """Train ``classifier`` for ``num_epochs`` epochs, evaluating it periodically.

    Each epoch runs ``train_net`` once. On epoch 1 and on every epoch divisible by
    ``test_frequency``, ``test_net`` evaluates the model on both ``train_loader`` and
    ``val_loader``, and the scheduler is stepped with the training-set evaluation loss.

    Args:
        classifier: Model to train (updated in place and also returned).
        num_epochs: Number of training epochs.
        train_loader: Loader used for training and for the training-set evaluation.
        val_loader: Loader used for the validation evaluation.
        criterion: Loss callable passed through to ``train_net`` / ``test_net``.
        optimizer: Optimizer passed through to ``train_net``.
        scheduler: Object with a ``step(metric)`` method, e.g. ``ReduceLROnPlateau``.
        test_frequency: Evaluate every this many epochs (plus always on epoch 1).

    Returns:
        tuple: ``(classifier, train_losses, val_losses, train_accs, val_accs,
        train_f1s, val_f1s)``. ``train_losses`` has one entry per epoch; the other
        lists have one entry per evaluation epoch.
    """
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    train_f1s = []
    val_f1s = []

    for epoch in range(1, num_epochs + 1):
        print("Starting epoch number " + str(epoch))
        train_loss = train_net(train_loader, classifier, criterion, optimizer)
        train_losses.append(train_loss)
        print("Loss for Training on Epoch " + str(epoch) + " is " + str(train_loss))
        if epoch % test_frequency == 0 or epoch == 1:
            print('Evaluating classifier')
            loss_train, acc_train, f1_train, _ = test_net(train_loader, classifier, criterion)
            train_accs.append(acc_train)
            train_f1s.append(f1_train)
            loss_test, acc_test, f1_test, _ = test_net(val_loader, classifier, criterion)
            val_losses.append(loss_test)
            val_accs.append(acc_test)
            val_f1s.append(f1_test)
            # Step only when a fresh metric exists. Stepping on non-evaluation epochs
            # would re-feed the previous evaluation's loss, which a plateau scheduler
            # counts as "no improvement" and uses to cut the LR prematurely.
            # NOTE(review): plateau detection is driven by the training-set loss;
            # the validation loss (`loss_test`) is the more usual signal -- confirm intent.
            scheduler.step(loss_train)

    return classifier, train_losses, val_losses, train_accs, val_accs, train_f1s, val_f1s


In [4]:
# Training schedule shared by every mission.
num_epochs = 60         # epochs per mission
test_frequency = 1      # evaluate on epoch 1 and on every `test_frequency`-th epoch
# Data / optimiser settings.
batch_size = 8          # samples per DataLoader batch
learning_rate = 0.001   # initial Adam step size (ReduceLROnPlateau may lower it)
weight_decay = 0.0001   # Adam weight decay

In [5]:
# One single-site mission per (input type, site) pair: all "Both" missions first, then
# all "S1" missions, each training and testing on the same site. `idx` numbers them 0, 1, 2, ...
# The loop variable is deliberately not called `type`: at notebook top level that would
# shadow the builtin `type()` for every later cell.
train_mission_list = []
for input_type in ['Both', 'S1']:
    for site in [1, 3, 5, 7, 9]:
        train_mission_list.append(
            train_mission(train_site=[site], test_site=[site], idx=len(train_mission_list),
                          type=input_type, model_name='WMM')
        )

In [6]:
# Training curves for every mission; the loop variables below only keep the last mission's.
mission_histories = []
for mission in train_mission_list:
    mission.mission_start()
    # Drop the previous mission's model/optimizer/scheduler before emptying the CUDA
    # cache; while these names still reference them, their GPU memory cannot be released.
    net = classifier = optimizer = scheduler = None
    gc.collect()
    torch.cuda.empty_cache()
    Train_data = Satellite_image_dataset(sites=mission.train_site, years=mission.train_years, type=mission.type, model="ConvLSTM")
    Train_dataloader = DataLoader(Train_data, batch_size, shuffle=True)
    Test_data = Satellite_image_dataset(sites=mission.test_site, years=mission.test_years, type=mission.type, model="ConvLSTM")
    # Evaluation needs no shuffling: accuracy/F1/kappa do not depend on sample order.
    Test_dataloader = DataLoader(Test_data, batch_size, shuffle=False)
    # NOTE(review): architecture constants (80 input channels, 10 timesteps, ...) presumably
    # mirror the dataset's layout -- confirm against Satellite_image_dataset.
    net = WMM(n_channels=80, n_classes=1, timesteps=10, n_convlstm=1, n_feature_maps=120).to(DEVICE)
    criteria = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(net.parameters(), learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=0.25, patience=3, mode='min', verbose=True)
    classifier, train_losses, val_losses, train_accs, val_accs, train_f1s, val_f1s = train(net, num_epochs, Train_dataloader, Test_dataloader, criteria, optimizer, scheduler, test_frequency)
    mission_histories.append({
        "train_losses": train_losses, "val_losses": val_losses,
        "train_accs": train_accs, "val_accs": val_accs,
        "train_f1s": train_f1s, "val_f1s": val_f1s,
    })

    # FIXME(review): these are hard-coded placeholder lists, not this mission's real
    # labels/predictions, so the recorded score is meaningless. Pass the actual values
    # once the argument order expected by `mission_get_score` is confirmed.
    mission.mission_get_score([0,0,1,1],[1,0,0,1],[0,0,1,1],[1,0,0,1])


2022-05-04 13:00:14 Index:0, Model: WMM, train_site: 1, test_site: 1, type: Both
Starting epoch number 1
Loss for Training on Epoch 1 is 0.4372395873069763
Evaluating classifier
Test result:
 test_loss: 0.6044, accuracy: 0.7569, f1: 0.7677, coppa: 0.5338
Test result:
 test_loss: 0.5807, accuracy: 0.8213, f1: 0.7979, coppa: 0.6424
2022-05-04 13:02:11 Index:1, Model: WMM, train_site: 3, test_site: 3, type: Both


  _warn_prf(
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)


Starting epoch number 1
Loss for Training on Epoch 1 is 0.5111772418022156
Evaluating classifier


KeyboardInterrupt: 