In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import cv2 as cv
import rasterio
import os
import glob
import warnings
import random
import numpy as np
import time
# Suppress runtime warnings # HFDT-CONDA
warnings.filterwarnings("ignore", message="invalid value encountered in scalar divide")
import wandb

In [2]:
from torch.utils.data import Dataset
from PIL import Image
import torch
import torch.nn as nn

In [3]:
from training_dataloader import CustomDataset, IterableCustomDataset
from model import CombinedModel
from torch.utils.data import DataLoader, random_split

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Pick the compute device: CUDA when torch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Bare expression so the notebook cell displays the chosen device.
device

'cuda'

In [5]:
# Run configuration: W&B identity plus every hyperparameter read by later cells.
config = dict(
    # Weights & Biases run identity (passed to wandb.init).
    project="Crop_Yield",
    group="MLCAS_Crop_Yield",
    name="MLCAS_Crop_Yield_v9_001",
    device=device,
    # Schedule: LR ramps linearly for `warmup_steps`; training runs `total_updates` steps.
    warmup_steps=2_000,
    total_updates=3_000,
    # Validate (and checkpoint) every this many optimizer steps.
    validation_frequency=20,
    batch_size=32,
    # Fraction of the merged dataset held out for validation.
    val_ratio=0.2,
    # AdamW settings.
    lr=2e-5,
    weight_decay=1e-4,
    betas=(0.9, 0.999),
)

In [6]:
# First dataset: ground-truth table, collection dates and data root all live
# under the 2022 DataPublication_final directory.
_root_2022 = ".././2022/DataPublication_final/"
dataset = CustomDataset(
    data_path=_root_2022 + "GroundTruth/HYBRID_HIPS_V3.5_ALLPLOTS.csv",
    date_path=_root_2022 + "GroundTruth/DateofCollection.xlsx",
    mother_path=_root_2022,
    device=device,
)

In [7]:
# Second dataset instance: no paths are passed, so CustomDataset falls back to
# its own default arguments (defined in training_dataloader, not visible here).
dataset2 = CustomDataset(device=device)


In [8]:
# Concatenate both datasets into one indexable dataset; the train/validation
# split is drawn from this combined pool.
merged_dataset = torch.utils.data.ConcatDataset([dataset, dataset2])
# Display the total number of samples.
len(merged_dataset)

2774

In [9]:
# Hold out a fraction of the merged data. The held-out part is named `test_*`
# but is sized by config["val_ratio"] and used as the validation set by the
# training loop.
test_ratio = config["val_ratio"]
test_size = int(len(merged_dataset) * test_ratio)
train_size = len(merged_dataset) - test_size

# Seed the split so the same samples end up in train/validation on every run;
# without this, validation losses and saved checkpoints are not comparable
# across kernel restarts. Override by adding "split_seed" to `config`.
split_generator = torch.Generator().manual_seed(config.get("split_seed", 42))
train_dataset, test_dataset = random_split(
    merged_dataset,
    [train_size, test_size],
    generator=split_generator,
)
print(len(train_dataset), len(test_dataset))
# Wrap the training subset in the project's iterable dataset (after printing,
# since the sizes above are taken from the plain Subset objects).
train_dataset = IterableCustomDataset(train_dataset)

batch_size = config["batch_size"]
# NOTE(review): with an iterable-style dataset and num_workers > 0, every worker
# iterates its own copy of the dataset unless the dataset shards by worker id —
# confirm IterableCustomDataset handles this, otherwise samples are duplicated.
train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=True, num_workers=6, shuffle=False)

2220 554




In [10]:
# Persist both dataset objects so the exact split can be reloaded later.
train_dataset_path, test_dataset_path = "train_dataset.pth", "test_dataset.pth"

for _dataset_obj, _target_path in (
    (train_dataset, train_dataset_path),
    (test_dataset, test_dataset_path),
):
    torch.save(_dataset_obj, _target_path)

In [11]:
def calculate_val_mse(model, val_loader):
    """Compute the mean squared error of `model` over every sample in `val_loader`.

    The squared errors of all samples are summed and divided by the total
    sample count, so the result does not depend on how many batches the loader
    yields or on the size of the final (possibly shorter) batch.

    Args:
        model: Callable module invoked as
            ``model(input_data, satellite_images, satellite_images_info, satellite_images_len)``.
        val_loader: Iterable of ``(input_data, target_data, info)`` batches, where
            ``info`` holds the tensors ``'satelliteImages'``,
            ``'satelliteImagesInfo'`` and ``'satelliteImagesLen'``.

    Returns:
        float: The per-sample MSE, or NaN when the loader yields no samples
        (NaN never compares as "better" than a previous best loss).
    """
    # Remember the current mode so evaluation does not silently flip a model
    # that the caller had deliberately left in eval mode.
    was_training = getattr(model, "training", True)
    model.eval()

    squared_error_sum = 0.0
    sample_count = 0

    try:
        with torch.no_grad():
            for input_data, target_data, info in val_loader:
                # `device` is the notebook-level device string.
                satellite_images = info['satelliteImages'].to(device)
                satellite_images_info = info['satelliteImagesInfo'].to(device)
                satellite_images_len = info['satelliteImagesLen'].to(device)
                preds = model(input_data, satellite_images, satellite_images_info, satellite_images_len)

                # Flatten both sides and accumulate in float64 on the CPU.
                flat_preds = preds.detach().cpu().reshape(-1).double()
                flat_targets = target_data.detach().cpu().reshape(-1).double()
                squared_error_sum += torch.sum((flat_preds - flat_targets) ** 2).item()
                sample_count += flat_preds.numel()
    finally:
        # Restore the original mode even if a batch raised.
        model.train(was_training)

    if sample_count == 0:
        return float("nan")
    return squared_error_sum / sample_count

In [12]:
# Build the combined model on the selected device.
model = CombinedModel(device=device)
# Switch to training mode; as the last expression of the cell this also
# displays the module's architecture.
model.train()



CombinedModel(
  (image_encoder): ResNetEncoder(
    (resnet): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         

In [13]:
# Fine-tuning policy for the feature encoder: freeze the whole backbone, then
# re-enable gradients for a few selected sub-modules.
feature_backbone = model.feature_encoder.model
feature_backbone.requires_grad_(False)

encoder_layers = feature_backbone.encoder.layer
trainable_feature_modules = [
    *encoder_layers[-4:],   # last four encoder layers
    *encoder_layers[0:1],   # first encoder layer
    feature_backbone.embeddings.visual_projection,
]
for trainable_module in trainable_feature_modules:
    trainable_module.requires_grad_(True)

In [14]:
# Fine-tuning policy for the image encoder: freeze the whole ResNet, then
# unfreeze only its last four top-level child modules.
resnet_backbone = model.image_encoder.resnet
resnet_backbone.requires_grad_(False)

for resnet_child in list(resnet_backbone.children())[-4:]:
    resnet_child.requires_grad_(True)

In [15]:
# Total number of parameter elements in the model, counting frozen and
# trainable parameters alike.
sum(p.numel() for p in model.parameters())

124885825

In [16]:
# Start the Weights & Biases run. The run identity (project/group/name) is read
# from `config`, and the full config dict is attached to the run.
wandb_identity = {key: config[key] for key in ("project", "group", "name")}
wandb.init(config=config, **wandb_identity)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mprzl[0m ([33mprzl101[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [17]:
# Pull the schedule lengths out of config: number of optimizer steps to run and
# number of steps over which the learning rate warms up.
total_updates, warmup_steps = config["total_updates"], config["warmup_steps"]

In [18]:
# Regression objective: mean squared error between predictions and targets.
criterion = nn.MSELoss()

# AdamW over all model parameters, with hyperparameters taken from config.
adamw_settings = {
    "lr": config["lr"],
    "weight_decay": config["weight_decay"],
    "betas": config["betas"],
}
optimizer = torch.optim.AdamW(model.parameters(), **adamw_settings)


def _warmup_multiplier(steps):
    """Linear warm-up factor: grows with the step count and is capped at 1."""
    # `warmup_steps` is read from the notebook namespace on each call.
    return min((steps + 1) / warmup_steps, 1)


# Multiply the base learning rate by the warm-up factor on every scheduler step.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _warmup_multiplier)

In [19]:
# Step-based training loop: one optimizer update per iteration, with a
# validation pass and checkpoint every `validation_frequency` steps.
validation_frequency = config["validation_frequency"]
train_losses = []
val_losses = []
best_val_loss = float('inf')
# Defined up front so the summary print below cannot hit an unbound name when
# no validation pass ever improves on the initial best (e.g. NaN losses, or
# total_updates < validation_frequency).
best_model_path = None
# Make sure the checkpoint directory exists before the first torch.save.
checkpoint_dir = "./prediction_models"
os.makedirs(checkpoint_dir, exist_ok=True)
trainloader_iter = iter(train_loader)

for step in range(0, total_updates):
    step_start_time = time.perf_counter()

    # Training loop
    try:
        inputs, targets, info = next(trainloader_iter)
    except StopIteration:
        # The loader ran dry before `total_updates` steps: start another pass.
        trainloader_iter = iter(train_loader)
        inputs, targets, info = next(trainloader_iter)
    optimizer.zero_grad()
    # NOTE(review): only satelliteImagesLen is moved to `device` here, whereas
    # calculate_val_mse moves all three info tensors — confirm the model
    # handles device placement of the other two itself.
    outputs = model(inputs, info['satelliteImages'], info['satelliteImagesInfo'], info['satelliteImagesLen'].to(device))
    loss = criterion(outputs.to(device).reshape(-1, 1), targets.float().to(device).reshape(-1, 1))
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Read the scalar loss once; it is reused for logging, history and printing.
    avg_train_loss = loss.item()
    wandb.log(
        {
            "train_loss": avg_train_loss,
            "learning_rate": scheduler.get_last_lr()[0],
        },
        step=step,
    )
    train_losses.append(avg_train_loss)

    step_end_time = time.perf_counter()
    step_time = step_end_time - step_start_time

    print(f'Step [{step+1}/{total_updates}], Train Loss: {avg_train_loss:.4f}, Time Elapsed: {step_time:.2f} seconds')

    if (step + 1) % validation_frequency == 0:
        print("")
        print("---"*20)
        val_loss = calculate_val_mse(model, test_loader)
        val_losses.append(val_loss)
        print(f'Validation Loss after step {step+1}: {val_loss:.4f}')
        # Periodic checkpoint, written on every validation pass.
        torch.save(model.state_dict(), f"{checkpoint_dir}/saved_model_{step+1}.pth")
        wandb.log(
            {
                "val_loss": val_loss,
            },
            step=step,
        )
        if val_loss < best_val_loss:
            # Separate "best" checkpoint, written only when validation improves.
            best_model_path = f"{checkpoint_dir}/best_val_model_{step+1}.pth"
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_path)
            print(f'Saved model with validation loss: {best_val_loss:.4f}')
        print("---"*20)

print(f'Best model saved at: {best_model_path}')
print(best_val_loss)



Step [1/3000], Train Loss: 0.3456, Time Elapsed: 44.63 seconds
Step [2/3000], Train Loss: 0.4508, Time Elapsed: 0.50 seconds
Step [3/3000], Train Loss: 0.2487, Time Elapsed: 0.46 seconds
Step [4/3000], Train Loss: 0.5372, Time Elapsed: 0.65 seconds
Step [5/3000], Train Loss: 0.3839, Time Elapsed: 0.48 seconds
Step [6/3000], Train Loss: 0.3551, Time Elapsed: 1.09 seconds
Step [7/3000], Train Loss: 0.4503, Time Elapsed: 34.13 seconds
Step [8/3000], Train Loss: 0.4571, Time Elapsed: 1.98 seconds
Step [9/3000], Train Loss: 0.3728, Time Elapsed: 0.47 seconds
Step [10/3000], Train Loss: 0.2293, Time Elapsed: 1.09 seconds
Step [11/3000], Train Loss: 0.4400, Time Elapsed: 0.45 seconds
Step [12/3000], Train Loss: 0.4028, Time Elapsed: 2.25 seconds
Step [13/3000], Train Loss: 0.5550, Time Elapsed: 32.52 seconds
Step [14/3000], Train Loss: 0.4248, Time Elapsed: 3.40 seconds
Step [15/3000], Train Loss: 0.4005, Time Elapsed: 0.47 seconds
Step [16/3000], Train Loss: 0.5147, Time Elapsed: 2.10 second