In [1]:
from neuralop.models import FNO
from fno_dataset import *

In [25]:
from datetime import datetime
# Settings shared by every split.
_shared_loader_kwargs = dict(batch_size=4, seq_len=20)

# Train split: 2018-01-01 .. 2018-05-01, shuffled, stride 10.
train_loader = SEVIRDataLoader(
    start_date=datetime(2018, 1, 1),
    end_date=datetime(2018, 5, 1),
    shuffle=True,
    stride=10,
    **_shared_loader_kwargs,
)
# Test split: 2017-01-01 .. 2017-07-01, fixed order, stride 4.
test_loader = SEVIRDataLoader(
    start_date=datetime(2017, 1, 1),
    end_date=datetime(2017, 7, 1),
    shuffle=False,
    stride=4,
    **_shared_loader_kwargs,
)
# Validation split: 2019-01-01 .. 2019-09-01, fixed order, stride 4.
val_loader = SEVIRDataLoader(
    start_date=datetime(2019, 1, 1),
    end_date=datetime(2019, 9, 1),
    shuffle=False,
    stride=4,
    **_shared_loader_kwargs,
)

# NOTE(review): len() of a loader may count batches rather than samples —
# confirm against SEVIRDataLoader before trusting the "dataset size" label.
for _size_label, _split_loader in (
    ("Train dataset size:", train_loader),
    ("Test dataset size:", test_loader),
    ("Validation dataset size:", val_loader),
):
    print(_size_label, len(_split_loader))

Train dataset size: 1383
Test dataset size: 990
Validation dataset size: 9634


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import os

# Seed torch and numpy so weight initialisation and any torch/numpy-driven
# shuffling repeat across runs (some CUDA kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Output directory for all figures written by this notebook.
os.makedirs('fno_plots', exist_ok=True)

# Use the GPU when present; fall back to CPU instead of crashing on a GPU-less
# kernel (training this model on CPU will be slow).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# FNO mapping 10 input channels to 10 output channels; the training loop feeds
# the first 10 steps of each sequence as input and the last 10 as target.
model = FNO(n_modes=(128, 128), hidden_channels=64,
            in_channels=10, out_channels=10).to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Per-epoch mean losses, filled by the training loop and plotted afterwards.
train_losses = []
test_losses = []
num_epochs = 1

# Colormap/normalisation used to render VIL frames in the prediction plots.
vil_colormap, vil_norm = vil_cmap(encoded=True)

def _split_vil_batch(batch, device, n_input=10, n_target=10):
    """Move ``batch['vil']`` to ``device`` and split it into input/target frame stacks.

    The tensor is permuted with ``(0, 3, 1, 2)``, so its last axis becomes axis 1.
    This presumes ``batch['vil']`` is (batch, H, W, time) — TODO confirm against
    SEVIRDataLoader. The first ``n_input`` steps along axis 1 are the model
    input; the remaining steps are the target.

    Returns:
        ``(input_frames, target_frames)`` tensors on ``device``.
    Raises:
        AssertionError: if axis 1 does not hold exactly ``n_input + n_target`` steps.
    """
    vil_sequence = batch['vil'].to(device).permute(0, 3, 1, 2)
    expected_steps = n_input + n_target
    # Report the count actually being checked (previously the message said 10 while checking 20).
    assert vil_sequence.shape[1] == expected_steps, \
        f"Expected {expected_steps} channels, got {vil_sequence.shape[1]}"
    return vil_sequence[:, :n_input], vil_sequence[:, n_input:]


def _save_test_prediction_plot(target_frames, predicted_frames, epoch, batch_idx, cmap, norm):
    """Save a 2x2 target-vs-prediction figure for the first two frames of sample 0.

    Row ``i`` shows target frame ``i + 1`` (left) and predicted frame ``i + 1`` (right).
    ``epoch`` is 0-based; the title and filename use ``epoch + 1``.
    """
    # Scale by 255 for display with the encoded VIL colormap; presumes the loader
    # yields values normalised to [0, 1] — TODO confirm.
    pred_np = predicted_frames[0, :2].detach().cpu().numpy() * 255
    target_np = target_frames[0, :2].detach().cpu().numpy() * 255

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    for row in range(2):
        panels = ((target_np[row], f'Target Frame {row + 1}'),
                  (pred_np[row], f'Predicted Frame {row + 1}'))
        for col, (frame, title) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(frame, cmap=cmap, norm=norm)
            ax.set_title(title)
            ax.axis('off')

    fig.suptitle(f'Testing - Epoch {epoch+1}, Batch {batch_idx}')
    fig.tight_layout()
    fig.savefig(f'fno_plots/test_predictions_epoch_{epoch+1}_batch_{batch_idx}.png',
                dpi=300, bbox_inches='tight')
    # Close this specific figure so repeated snapshots do not accumulate in memory.
    plt.close(fig)


def _mean_or_nan(values):
    """Mean of a list of floats, or NaN for an empty list (avoids numpy's empty-slice warning)."""
    return float(np.mean(values)) if values else float('nan')


# Follow whatever device the model lives on instead of hardcoding 'cuda'.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    epoch_train_losses = []
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch+1}/{num_epochs} [Train]')

    for batch_idx, data in pbar:
        input_frames, target_frames = _split_vil_batch(data, device)

        predicted_frames = model(input_frames)
        loss = criterion(predicted_frames, target_frames)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for logging and the progress bar.
        loss_value = loss.item()
        epoch_train_losses.append(loss_value)
        pbar.set_postfix({'Loss': f'{loss_value:.4f}'})

    # ---- Evaluation pass on the test split (no gradients) ----
    model.eval()
    epoch_test_losses = []
    with torch.no_grad():
        test_pbar = tqdm(enumerate(test_loader), total=len(test_loader), desc=f'Epoch {epoch+1}/{num_epochs} [Test]')
        for batch_idx, data in test_pbar:
            input_frames, target_frames = _split_vil_batch(data, device)

            predicted_frames = model(input_frames)
            loss_value = criterion(predicted_frames, target_frames).item()

            epoch_test_losses.append(loss_value)
            test_pbar.set_postfix({'Loss': f'{loss_value:.4f}'})

            # Every 100th test batch, save a target-vs-prediction snapshot.
            if batch_idx % 100 == 0:
                _save_test_prediction_plot(target_frames, predicted_frames, epoch, batch_idx,
                                           vil_colormap, vil_norm)

    # Per-epoch mean of per-batch losses.
    avg_train_loss = _mean_or_nan(epoch_train_losses)
    avg_test_loss = _mean_or_nan(epoch_test_losses)
    train_losses.append(avg_train_loss)
    test_losses.append(avg_test_loss)

    print(f'Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}')

# Plot per-epoch mean losses against 1-based epoch numbers (matching the
# "Epoch N" log lines). Markers are required: with a single epoch each series
# is one point, and a marker-less line ('b-') draws nothing for one point.
epoch_numbers = list(range(1, len(train_losses) + 1))

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(epoch_numbers, train_losses, 'b-o', linewidth=2, label='Train Loss')
ax.plot(epoch_numbers, test_losses, 'r-o', linewidth=2, label='Test Loss')
ax.set(title='Training and Testing Loss', xlabel='Epoch', ylabel='Loss')
# Epochs are whole numbers; avoid fractional tick labels like 0.96 / 1.04.
ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(integer=True))
ax.legend()
ax.grid(True)
fig.savefig('fno_plots/training_curves.png', dpi=300, bbox_inches='tight')
plt.close(fig)

print("Training completed! Plots saved in fno_plots/")