In [1]:
# Setup cell: library imports, then make the project-local packages importable.
print("Importing misc libraries")
import sys
import os
import torch
import tqdm
from pathlib import Path
import matplotlib.pyplot as plt

print("Updating sys.path")
# Append the grandparent and the parent of the current working directory to sys.path
# (skipping entries already present) so the project-local packages imported below
# (unet, dataset, loss) can be resolved. This must run before those imports.
# NOTE(review): assumes the kernel was started in the notebook's own directory — confirm.
project_root = Path.cwd().parent.parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
models_path = Path.cwd().parent
if str(models_path) not in sys.path:
    sys.path.append(str(models_path))
    
print("Importing torch libraries")
from torch.utils.data import DataLoader
import multiprocessing
from torchvision.transforms import ToTensor
import torch
from torch.optim import Adam
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.amp import autocast, GradScaler
    
print("Importing unet lib")
import importlib
import unet.Unet as u
# reload() re-executes the module, so edits to unet.Unet are picked up when this cell
# is re-run in a live kernel (a plain import would return the cached module).
importlib.reload(u)

print("Importing dataset lib")
from dataset.IntersectionDataset import IntersectionDataset, IntersectionDataset2, IntersectionDatasetClasses, custom_collate_fn
import loss.loss_lib as ll
# Same reason as above: pick up edits to the loss library without restarting the kernel.
importlib.reload(ll)







In [None]:
import cv2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Smoke test: push one image through a freshly built UNet and check the output shape.
SMOKE_TEST_IMAGE = "../unet/satellite.png"
img = cv2.imread(SMOKE_TEST_IMAGE)
if img is None:
    # cv2.imread returns None (instead of raising) when the file is missing or unreadable.
    raise FileNotFoundError(f"Could not read image: {SMOKE_TEST_IMAGE}")
# OpenCV loads BGR uint8 (H, W, C). Convert to RGB and scale to [0, 1], which is what
# torchvision's ToTensor produces for uint8 images in the training datasets.
# NOTE(review): assumes the dataset loads images as RGB — confirm.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().div(255.0).to(device)
model = u.UNet(n_channels=3, n_classes=1).to(device)
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    output = model(img_t)
print("Output shape:", output.shape)  # for a 400x400 input: torch.Size([1, 1, 400, 400])
#u.display_output(output)

## Model

In [None]:
# Pick the GPU when present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Single output channel, trained against the path_line mask below.
model = u.UNet(n_channels=3, n_classes=1).to(device)

## Dataset

In [None]:
# Build the train and test splits with identical ToTensor transforms.
# dataset_dir / img_transform / path_transform end up holding the test-split values.
split_dirs = {
    "train": "../../dataset/dataset/train",
    "test": "../../dataset/dataset/test",
}
split_datasets = {}
for split_name, dataset_dir in split_dirs.items():
    img_transform = ToTensor()
    path_transform = ToTensor()
    split_datasets[split_name] = IntersectionDataset(
        root_dir=dataset_dir,
        transform=img_transform,
        path_transform=path_transform,
    )
dataset_train = split_datasets["train"]
dataset_test = split_datasets["test"]

In [None]:
# Report the number of samples in each split (train first, then test).
for split_dataset in (dataset_train, dataset_test):
    print(len(split_dataset))



## Dataloader

In [None]:
# One loader worker per CPU core; batch size shared by both loaders.
num_workers = multiprocessing.cpu_count()
b = 4

train_dataloader = DataLoader(dataset_train, batch_size=b, shuffle=True, num_workers=num_workers)
# Evaluation does not benefit from shuffling; a fixed order keeps test passes reproducible.
test_dataloader = DataLoader(dataset_test, batch_size=b, shuffle=False, num_workers=num_workers)

## Optimizer

In [None]:
# Adam over every UNet parameter, base learning rate 1e-3.
optimizer = Adam(params=model.parameters(), lr=1e-3)

## Scheduler

In [None]:
# Cosine annealing from the optimizer's lr down to eta_min, restarting every T_0
# scheduler steps; T_mult=1 keeps every cycle the same length.
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=20,
    T_mult=1,
    eta_min=1e-5,
)

## Loss

In [None]:
# Loss components from the project loss library, blended by `alpha` below.
lc = ll.CmapLoss().to(device)
lb = ll.BCELoss()
alpha = 0.5  # weight of the CmapLoss term; (1 - alpha) weights the BCE term

def total_loss(output, target, alpha = 0.5):
    """Blend the module-level ``lc`` and ``lb`` losses.

    Returns ``alpha * lc(output, target) + (1 - alpha) * lb(output, target)``
    moved to ``device``. The default ``alpha`` is the literal 0.5; it does not
    follow the module-level ``alpha`` variable.
    """
    cmap_term = alpha * lc(output, target)
    bce_term = (1 - alpha) * lb(output, target)
    return (cmap_term + bce_term).to(device)

## Training loop

In [None]:
# Run length and loss blend for the first training loop.
n_epochs = 500
alpha = 0.5
epochs = tqdm.tqdm(iterable=range(n_epochs))

In [None]:
# Replace the project BCELoss with the logits-based variant (it applies the sigmoid itself).
lb = torch.nn.BCEWithLogitsLoss().to(device)

In [None]:
# Request synchronous CUDA kernel launches (for debugging).
# NOTE: CUDA reads this variable when its context is created; once tensors have already
# been moved to the GPU (as in the cells above), setting it here likely has no effect.
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})

In [None]:
# First training run: alpha-weighted blend of CmapLoss (on the cold map) and BCE (on path_line).
for epoch in epochs:
    model.train()
    for batch in train_dataloader:
        satellite = batch["satellite"].to(device)
        path_line = batch["path_line"].to(device)
        cold_map = batch["cold_map"].to(device)
        cmap_f = torch.flatten(cold_map)
        output = model(satellite)
        # NOTE(review): arguments are passed as (cold map, prediction), the reverse of
        # total_loss's lc(output, target) — confirm which order CmapLoss expects.
        L_cmap = lc(cmap_f, output)
        L_bce = lb(output, path_line)
        loss = alpha * L_cmap + (1 - alpha) * L_bce
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Advance the cosine-restart schedule once per epoch, as the later training loops do.
    # Without this the scheduler is never stepped and the lr stays at its initial value.
    scheduler.step()
        


In [None]:
# Persist only the weights (the state_dict), not the whole module object.
torch.save(obj=model.state_dict(), f="model.pth")

In [None]:
# Visualise the most recent `output` tensor via the project helper, with thresholded=False.
u.display_output(output, thresholded=False, threshold=0.5)

In [None]:
# Shape of the first sample's output, plus the global max and min over the batch.
first_shape, peak, trough = output[0].shape, output.max(), output.min()
print(first_shape, peak, trough)

In [None]:
import matplotlib.pyplot as plt
import random

# Visual check: for each selected checkpoint, show 6 test images next to the model's
# sigmoid probability map.
CKPT_DIR = "ckpt"
MAX_CKPTS_TO_SHOW = 1  # the original loop stopped (via `break`) after the first checkpoint

# Note: this is lexicographic order, so "..._100.pth" sorts before "..._50.pth".
ckpts = sorted([f for f in os.listdir(CKPT_DIR) if f.startswith("bce_new") and f.endswith(".pth")])
print(ckpts)

# The same 6 test samples are used for every checkpoint, so the figures are comparable.
rs = random.sample(range(len(dataset_test)), 6)
print(rs)

cols = 6
rows = 2

for ck in ckpts[:MAX_CKPTS_TO_SHOW]:
    fig = plt.figure(figsize=(24, 8))

    # Load the checkpoint named by `ck`; previously a fixed file was loaded on every
    # iteration, so the title did not match the weights shown. map_location lets
    # CUDA-saved checkpoints load on a CPU-only machine.
    c = torch.load(os.path.join(CKPT_DIR, ck), map_location=device)

    model2 = u.UNet(n_channels=3, n_classes=1).to(device)
    model2.load_state_dict(c["model_state_dict"])
    model2.eval()

    i = 1
    with torch.no_grad():  # inference only
        for cnt in range(rows * cols // 2):  # one (image, prediction) pair per sample
            sample = dataset_test[rs[cnt]]["satellite"]
            fig.add_subplot(rows, cols, i)
            plt.imshow(sample.permute(1, 2, 0))
            plt.axis("off")
            i += 1

            fig.add_subplot(rows, cols, i)
            prob = torch.sigmoid(model2(sample.to(device).unsqueeze(0)))
            plt.imshow(prob.squeeze(0).squeeze(0).cpu().numpy(), cmap="gray")
            plt.axis("off")
            i += 1

    fig.suptitle(f"{ck}")
    fig.tight_layout()
    plt.show()
    









## Training with per-epoch loss and accuracy statistics

In [None]:
# Fresh model / optimizer / scheduler for the statistics run. The old objects may not
# exist (fresh kernel or partial run), so deleting them is best-effort — the same
# pattern the later sections use.
try:
    del model
except NameError:
    pass
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = u.UNet(n_channels=3, n_classes=1).to(device)

try:
    del optimizer
except NameError:
    pass
optimizer = Adam(model.parameters(), lr=1e-3)

try:
    del scheduler
except NameError:
    pass
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=1, eta_min=1e-5)

In [None]:
def _bce_epoch(loader, loss_fn, train):
    """Run one pass of the global `model` over `loader` with a per-pixel BCE objective.

    When ``train`` is True the model is in train mode and is updated with the global
    `optimizer`; otherwise it runs in eval mode without gradient tracking.

    Returns:
        (mean batch loss, pixel accuracy of ``sigmoid(output) > 0.5`` against path_line)
    """
    model.train(train)
    running_loss = 0.0
    running_correct = 0
    running_total = 0
    with torch.set_grad_enabled(train):
        for batch in loader:
            satellite = batch["satellite"].to(device)
            path_line = batch["path_line"].to(device)
            output = model(satellite)
            loss = loss_fn(output, path_line)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
            preds = (torch.sigmoid(output) > 0.5).float()
            running_correct += (preds == path_line).sum().item()
            running_total += path_line.numel()
    return running_loss / len(loader), running_correct / running_total


# The accuracy applies a sigmoid before thresholding, so the loss must consume logits.
# Defined here so the cell does not depend on which earlier cell last rebound `lb`.
lb = torch.nn.BCEWithLogitsLoss().to(device)

train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

n_epochs = 200
epochs = tqdm.tqdm(range(n_epochs))

for epoch in epochs:
    avg_train_loss, train_accuracy = _bce_epoch(train_dataloader, lb, train=True)
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)

    avg_test_loss, test_accuracy = _bce_epoch(test_dataloader, lb, train=False)
    test_losses.append(avg_test_loss)
    test_accuracies.append(test_accuracy)

    scheduler.step()  # one cosine-restart step per epoch

    if (epoch + 1) % 50 == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_losses': train_losses,
            'test_losses': test_losses,
            'train_accuracies': train_accuracies,
            'test_accuracies': test_accuracies,
        }
        os.makedirs('./ckpt', exist_ok=True)
        torch.save(checkpoint, f'./ckpt/bce_new_dataset_checkpoint_epoch_{epoch + 1}.pth')

torch.save(model.state_dict(), "model_200e_bce_new_dataset.pth")
            

In [None]:
# Plotting the loss graphs
plt.figure(figsize=(10, 6))
plt.plot(range(1, n_epochs + 1), train_losses, label="Train Loss")
plt.plot(range(1, n_epochs + 1), test_losses, label="Test Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Evaluation Loss Over Epochs")
plt.legend()
plt.show()

# Plotting the accuracy graphs
plt.figure(figsize=(10, 6))
plt.plot(range(1, n_epochs + 1), train_accuracies, label="Train Accuracy")
plt.plot(range(1, n_epochs + 1), test_accuracies, label="Test Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Training and Evaluation Accuracy Over Epochs")
plt.legend()
plt.show()

## Exit-direction classes (left, right, ahead) plus background

In [None]:
def map_exit_to_class(exit_x, exit_y, image_size=400):
    """Map a path's exit point on the image border to a direction class.

    ``exit_y == 0`` is "left", ``exit_y == image_size - 1`` is "right" and
    ``exit_x == 0`` is "ahead". This suggests ``exit_x`` is the row index and
    ``exit_y`` the column index — confirm against the dataset's ``ee_data``.
    Class 0 is reserved for background by the callers.

    Args:
        exit_x: exit coordinate compared against 0 for "ahead".
        exit_y: exit coordinate compared against the left/right edges.
        image_size: side length of the (assumed square) image; the default of 400
            reproduces the previously hard-coded right edge at 399.

    Returns:
        1 (left), 2 (right) or 3 (ahead). The checks run in that order, so a corner
        that lies on two edges resolves to the first match.

    Raises:
        ValueError: if the point lies on none of those three edges.
    """
    if exit_y == 0:
        return 1  # left
    if exit_y == image_size - 1:
        return 2  # right
    if exit_x == 0:
        return 3  # ahead
    raise ValueError(f"Unexpected exit position: ({exit_x}, {exit_y})")


In [None]:
def init_weights(m):
    """Initialise one module in place; intended for use with ``model.apply``.

    - Conv2d with a 1x1 kernel: weights ~ N(0, 0.001).
    - Any other Conv2d: Kaiming-normal weights (fan_out, ReLU gain).
    - Conv2d biases (when present): zero.
    - BatchNorm2d: weight 1, bias 0.
    Every other module type is left untouched.
    """
    if isinstance(m, nn.Conv2d):
        if m.kernel_size == (1, 1):
            nn.init.normal_(m.weight, mean=0, std=0.001)
        else:
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    

In [None]:
# Datasets with path_line masks and ee_data (entry/exit metadata) for the exit-class experiment.
dataset_dir = "../../dataset/dataset/train"
img_transform = ToTensor()
path_transform = ToTensor()
dataset_train = IntersectionDataset(root_dir=dataset_dir,
                                    transform=img_transform,
                                    path_transform=path_transform)

dataset_dir = "../../dataset/dataset/test"
img_transform = ToTensor()
path_transform = ToTensor()
dataset_test = IntersectionDataset(root_dir=dataset_dir,
                                   transform=img_transform,
                                   path_transform=path_transform)

num_workers = multiprocessing.cpu_count()
b = 4

train_dataloader = DataLoader(dataset_train, batch_size=b, shuffle=True, num_workers=num_workers)
# Evaluation does not benefit from shuffling; a fixed order keeps test passes reproducible.
test_dataloader = DataLoader(dataset_test, batch_size=b, shuffle=False, num_workers=num_workers)

In [None]:
# Start from scratch: drop any model / optimizer / scheduler left in the kernel.
# Names that do not exist yet are simply skipped.
for _stale_name in ("model", "optimizer", "scheduler"):
    globals().pop(_stale_name, None)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 5 output channels: background, left, right, ahead, stacked
model = u.UNet(n_channels=3, n_classes=5).to(device)
model.apply(init_weights)

optimizer = Adam(model.parameters(), lr=1e-4)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=1, eta_min=1e-5)

# Unweighted per-pixel cross-entropy over the class channels.
lb = torch.nn.CrossEntropyLoss().to(device)

In [None]:
def _exit_class_targets(path_line, ee_data):
    """Build per-pixel class targets: 0 = background, path pixels = the path's exit class.

    Args:
        path_line: (B, C, H, W) path masks; channel 0 > 0 marks path pixels.
        ee_data: collated metadata where ``ee_data["exit"]["x"][i]`` and
            ``ee_data["exit"]["y"][i]`` are sample i's exit coordinates.

    Returns:
        (B, H, W) long tensor on the same device as ``path_line``.
    """
    B, _, H, W = path_line.shape
    target = torch.zeros((B, H, W), dtype=torch.long, device=path_line.device)
    for i in range(B):
        class_label = map_exit_to_class(ee_data["exit"]["x"][i].item(),
                                        ee_data["exit"]["y"][i].item())
        target[i, path_line[i, 0] > 0] = class_label
    return target


def _ce_epoch(loader, train):
    """One pass of the global `model` over `loader` with cross-entropy on exit-class targets.

    Trains with the global `optimizer` when ``train`` is True; otherwise runs in eval mode
    without gradient tracking.

    Returns:
        (mean batch loss, accuracy over non-background target pixels). The accuracy is NaN
        when the loader contains no path pixels at all.
    """
    model.train(train)
    running_loss = 0.0
    running_correct = 0
    running_total = 0
    with torch.set_grad_enabled(train):
        for batch in loader:
            satellite = batch["satellite"].to(device)
            target = _exit_class_targets(batch["path_line"].to(device), batch["ee_data"])
            output = model(satellite)
            loss = lb(output, target)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
            # Accuracy is measured only where a path exists; background is ignored.
            valid_mask = target != 0
            preds = torch.argmax(output, dim=1)
            running_correct += (preds[valid_mask] == target[valid_mask]).sum().item()
            running_total += valid_mask.sum().item()
    accuracy = running_correct / running_total if running_total else float("nan")
    return running_loss / len(loader), accuracy


train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

n_epochs = 10
CKPT_EVERY = 25
epochs = tqdm.tqdm(range(n_epochs), desc="Training")

for epoch in epochs:
    avg_train_loss, train_accuracy = _ce_epoch(train_dataloader, train=True)
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)

    avg_test_loss, test_accuracy = _ce_epoch(test_dataloader, train=False)
    test_losses.append(avg_test_loss)
    test_accuracies.append(test_accuracy)

    scheduler.step()  # one cosine-restart step per epoch

    # Also checkpoint the final epoch; otherwise runs shorter than CKPT_EVERY
    # (e.g. n_epochs = 10) never save a checkpoint at all.
    if (epoch + 1) % CKPT_EVERY == 0 or epoch + 1 == n_epochs:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_losses': train_losses,
            'test_losses': test_losses,
            'train_accuracies': train_accuracies,
            'test_accuracies': test_accuracies,
        }
        os.makedirs('./ckpt', exist_ok=True)
        torch.save(checkpoint, f'./ckpt/bce_checkpoint_epoch_{epoch + 1}_4classes.pth')

    print(f"Epoch {epoch+1}/{n_epochs} - Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")

In [None]:
# Debug: inspect the exit-class target and the model's prediction for one training batch.
# This cell deliberately leaves train_losses / test_losses / ... untouched, so the
# history from the training run above is still available for plotting.
batch = next(iter(train_dataloader))
satellite = batch["satellite"].to(device)
path_line = batch["path_line"].to(device)
path_line_ee = batch["ee_data"]

B, _, H, W = path_line.shape
# Sample to display: the last one in the batch. (len(batch) would count the batch
# dict's keys, not its samples.)
l = B - 1
target = torch.full((B, H, W), 0, dtype=torch.long, device=device)

for i in range(B):
    exit_x = path_line_ee["exit"]["x"][i].item()
    exit_y = path_line_ee["exit"]["y"][i].item()
    class_label = map_exit_to_class(exit_x, exit_y)
    mask = path_line[i, 0] > 0
    target[i, mask] = class_label

# After the loop, class_label / exit_x / exit_y belong to sample B - 1 == l,
# so the title describes the image being shown.
plt.imshow(target[l].cpu().numpy())
plt.title(f"Class: {class_label}, Exit: ({exit_x}, {exit_y})")
plt.colorbar()
plt.show()

# Eval mode + no_grad: a debug forward pass should not build an autograd graph or
# update any normalisation running statistics.
model.eval()
with torch.no_grad():
    output = model(satellite)
    loss = lb(output, target)

print(f"{output.shape}, {output.max()}, {output.min()}")
print(f"{target.shape}, {target.max()}, {target.min()}")
print(f"Loss: {loss.item():.4f}")

p = torch.argmax(output, dim=1)
plt.imshow(p[l].cpu().numpy())
plt.title("Predicted")
plt.colorbar()
plt.show()

valid_mask = target != 0
plt.imshow(valid_mask[l].cpu().numpy())
plt.title("Valid mask")
plt.colorbar()
plt.show()

In [None]:
# Save only the weights of the exit-class model.
# NOTE(review): the "200e" in the file name does not reflect n_epochs above — confirm.
weights_path = "model_200e_ce_new_dataset_4class.pth"
torch.save(model.state_dict(), weights_path)

In [None]:
# Plotting the loss graphs
plt.figure(figsize=(10, 6))
plt.plot(range(1, n_epochs + 1), train_losses, label="Train Loss")
plt.plot(range(1, n_epochs + 1), test_losses, label="Test Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Evaluation Loss Over Epochs")
plt.legend()
plt.show()

# Plotting the accuracy graphs
plt.figure(figsize=(10, 6))
plt.plot(range(1, n_epochs + 1), train_accuracies, label="Train Accuracy")
plt.plot(range(1, n_epochs + 1), test_accuracies, label="Test Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Training and Evaluation Accuracy Over Epochs")
plt.legend()
plt.show()





In [None]:
import random

# Show 6 random test images, each next to output channel 0 of the 5-class model.
# Channel 0 is the background class; the values are raw logits (no softmax applied).
rs = random.sample(range(len(dataset_test)), 6)
print(rs)

cols = 6
rows = 2
fig = plt.figure(figsize=(24, 8))
i = 1

model.eval()  # inference mode, so layers such as dropout/BatchNorm behave deterministically
with torch.no_grad():
    for idx in rs:
        sat = dataset_test[idx]["satellite"]
        fig.add_subplot(rows, cols, i)
        plt.imshow(sat.permute(1, 2, 0))
        plt.axis("off")
        i += 1

        fig.add_subplot(rows, cols, i)
        output = model(sat.to(device).unsqueeze(0))
        # Use torch.argmax(output, dim=1) instead to show the predicted class map.
        output = output[0][0].cpu().numpy()
        plt.imshow(output, cmap="inferno")
        plt.colorbar()
        plt.axis("off")
        i += 1

fig.tight_layout()
plt.show()





## Combine all three paths into one multi-class label map

In [None]:
# Datasets that provide combined per-pixel class labels (all paths of an intersection).
dataset_dir = "../../dataset/dataset/train"
img_transform = ToTensor()
path_transform = ToTensor()
dataset_train = IntersectionDatasetClasses(root_dir=dataset_dir,
                                           transform=img_transform,
                                           path_transform=path_transform)

dataset_dir = "../../dataset/dataset/test"
img_transform = ToTensor()
path_transform = ToTensor()
dataset_test = IntersectionDatasetClasses(root_dir=dataset_dir,
                                          transform=img_transform,
                                          path_transform=path_transform)
print(len(dataset_train))
print(len(dataset_test))

num_workers = multiprocessing.cpu_count()
b = 4

# pin_memory pairs with the non_blocking=True host->device copies in the training loop;
# persistent_workers keeps the worker processes alive between epochs.
train_dataloader = DataLoader(dataset_train, batch_size=b, shuffle=True, num_workers=num_workers, pin_memory=True, persistent_workers=True, collate_fn=custom_collate_fn)
# Evaluation does not benefit from shuffling; a fixed order keeps test passes reproducible.
test_dataloader = DataLoader(dataset_test, batch_size=b, shuffle=False, num_workers=num_workers, pin_memory=True, persistent_workers=True, collate_fn=custom_collate_fn)
def init_weights(m):
    """Initialise one module in place; intended for use with ``model.apply``.

    - Conv2d with a 1x1 kernel: weights ~ N(0, 0.001).
    - Any other Conv2d: Kaiming-normal weights (fan_out, ReLU gain).
    - Conv2d biases (when present): zero.
    - BatchNorm2d: weight 1, bias 0.
    Every other module type is left untouched.
    """
    if isinstance(m, nn.Conv2d):
        if m.kernel_size == (1, 1):
            nn.init.normal_(m.weight, mean=0, std=0.001)
        else:
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)



In [None]:
# Fresh 5-class model, optimizer, scheduler, AMP grad scaler and class-weighted loss.
# The old objects may not exist on a fresh kernel, so deletion is best-effort.
try:
    del model
except NameError:
    pass
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = u.UNet(n_channels=3, n_classes=5).to(device) # background, left, right, ahead, stacked
model.apply(init_weights)

try:
    del optimizer
except NameError:
    pass
optimizer = Adam(model.parameters(), lr=1e-4)

try:
    del scheduler
except NameError:
    pass
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=1, eta_min=1e-5)

try:
    del scaler
except NameError:
    pass
# torch.amp.GradScaler takes the device *type* as a string ("cuda", "cpu", ...).
scaler = GradScaler(device=device.type)

# Inverse-frequency class weights, normalised to sum to 1. The counts add up to
# 160000 = 400 * 400, so they presumably approximate per-class pixel counts in one
# 400x400 image, with background dominating — TODO confirm against the dataset.
class_counts = torch.tensor([152000, 2000, 2000, 2000, 2000], dtype=torch.float)
weights = 1.0 / class_counts
weights = weights / weights.sum()
# The weight tensor is already on `device`, so the loss module needs no extra .to().
lb = torch.nn.CrossEntropyLoss(weight=weights.to(device))

In [None]:
# Debug: rebuild a multi-class label map from the per-path masks of one training batch
# and display it next to the first satellite image.
batch = next(iter(train_dataloader))
satellite = batch["satellite"].to(device)
# One list of path entries per sample; each entry holds a "path_line" mask (C, H, W)
# and "ee_data" (layout as produced by custom_collate_fn — TODO confirm).
path_line = batch["paths"]

plt.imshow(satellite[0].permute(1, 2, 0).cpu().numpy())
plt.axis("off")
plt.show()

print(path_line[0][0]["path_line"].shape)

B = len(path_line)  # batch size (number of samples in this batch)
_, H, W = path_line[0][0]["path_line"].shape
STACKED_CLASS = 4  # label for pixels covered by more than one path
combined = torch.zeros((B, H, W), dtype=torch.long)

# Label each sample from its *own* paths (previously sample 0's map was broadcast
# to the whole batch).
for sample_idx, sample_paths in enumerate(path_line):
    coverage = torch.zeros((H, W), dtype=torch.long)
    for entry in sample_paths:
        mask = entry["path_line"][0] > 0
        class_label = map_exit_to_class(entry["ee_data"]["exit"]["x"], entry["ee_data"]["exit"]["y"])
        combined[sample_idx, mask] = class_label
        coverage += mask.long()
    # Overlapping paths get the dedicated "stacked" class. Summing the labels instead
    # would collide (e.g. left 1 + right 2 = 3 = "ahead").
    combined[sample_idx, coverage > 1] = STACKED_CLASS

combined = combined.to(device)

plt.imshow(combined[0].cpu().numpy(), cmap="inferno")
plt.colorbar()
plt.axis("off")
plt.show()

In [None]:
# Sanity check on the debug batch above (`satellite` and `combined` come from the same batch).
print("Passing through the model")
# Eval mode + no_grad: this check should neither record gradients nor alter any
# normalisation statistics before training starts.
model.eval()
with torch.no_grad():
    output = model(satellite)
print("Output shape:", output.shape)  # (B, 5, H, W): one logit channel per class
L_ce = lb(output, combined)
print("Loss calculation")
loss = L_ce
loss.item()





In [None]:
# Channel 0 of the first sample's output, as a heat map.
first_channel_map = output[0][0].detach().cpu().numpy()
plt.imshow(first_channel_map, cmap="inferno")
plt.colorbar()
plt.axis("off")
plt.show()



In [None]:
# Aliased import: a bare `from tqdm.notebook import tqdm` would shadow the `tqdm`
# module that earlier cells call as `tqdm.tqdm(...)`.
from tqdm.notebook import tqdm as notebook_tqdm

train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

n_epochs = 500
epochs = notebook_tqdm(range(n_epochs), desc="Training", unit=" epoch")

# Mixed precision only on CUDA. On other devices autocast("cuda") was a no-op anyway.
use_amp = device.type == "cuda"

for epoch in epochs:
    model.train()
    running_train_loss = 0.0
    running_train_correct = 0
    running_train_total = 0

    batches = notebook_tqdm(train_dataloader, desc="Batches", unit=" batch", leave=False)

    for batch in batches:
        satellite = batch["satellite"].to(device, non_blocking=True)
        class_labels = batch["class_labels"].to(device, non_blocking=True)
        # Drop the singleton channel dim: CrossEntropyLoss wants (B, H, W) class indices.
        class_labels = class_labels.squeeze(1)

        optimizer.zero_grad()

        with autocast(device.type, enabled=use_amp):
            output = model(satellite)
            loss = lb(output, class_labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_train_loss += loss.item()

        p = torch.argmax(output, dim=1)
        running_train_correct += (p == class_labels).sum().item()
        # Count pixels, not samples: `correct` is summed over every pixel.
        running_train_total += class_labels.numel()

        batches.set_postfix({"Loss": loss.item()})

    avg_train_loss = running_train_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)
    train_accuracy = running_train_correct / running_train_total
    train_accuracies.append(train_accuracy)

    model.eval()
    running_test_loss = 0.0
    running_test_correct = 0
    running_test_total = 0

    test_batches = notebook_tqdm(test_dataloader, desc="Batches", unit=" batch", leave=False)
    with torch.no_grad():
        for batch in test_batches:
            satellite = batch["satellite"].to(device, non_blocking=True)
            class_labels = batch["class_labels"].to(device, non_blocking=True)
            class_labels = class_labels.squeeze(1)

            with autocast(device.type, enabled=use_amp):
                output = model(satellite)
                loss = lb(output, class_labels)

            running_test_loss += loss.item()

            p = torch.argmax(output, dim=1)
            running_test_correct += (p == class_labels).sum().item()
            running_test_total += class_labels.numel()  # pixels, as above

            test_batches.set_postfix({"Loss": loss.item()})

    test_batches.close()

    avg_test_loss = running_test_loss / len(test_dataloader)
    test_losses.append(avg_test_loss)

    test_accuracy = running_test_correct / running_test_total
    test_accuracies.append(test_accuracy)

    scheduler.step()  # one cosine-restart step per epoch

    if (epoch + 1) % 50 == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_losses': train_losses,
            'test_losses': test_losses,
            'train_accuracies': train_accuracies,
            'test_accuracies': test_accuracies,
        }
        os.makedirs('./ckpt', exist_ok=True)
        torch.save(checkpoint, f'./ckpt/bce_checkpoint_epoch_{epoch + 1}_5classes.pth')

    epochs.set_postfix({"Train Loss": avg_train_loss, "Test Loss": avg_test_loss, "Train Accuracy": train_accuracy, "Test Accuracy": test_accuracy})
    batches.close()

epochs.close()







In [None]:
# Post-training check on the last *test* batch from the training loop. `satellite` and
# `class_labels` come from that same batch, so they are scored together. (`combined`
# belongs to the earlier debug batch and generally does not match `satellite`.)
print("Passing through the model")
model.eval()
with torch.no_grad():
    output = model(satellite)
print("Output shape:", output.shape)  # (B, 5, H, W): one logit channel per class
L_ce = lb(output, class_labels)
print("Loss calculation")
loss = L_ce
print(loss.item())

plt.imshow(output[0][0].detach().cpu().numpy(), cmap="inferno")
plt.colorbar()
plt.axis("off")
plt.show()

