In [None]:
# Temporary workaround for the kernel crashing.
# NOTE(review): presumably the crash is the "duplicate OpenMP runtime" abort;
# this flag must be set before the heavy numeric libraries are imported.
import os

os.environ.update({'KMP_DUPLICATE_LIB_OK': 'TRUE'})

In [None]:
import torch
import torch.nn as nn
from model_helper import MobileNetV2, Inv2d
import wandb
import mysql.connector as connector
from pathlib import Path
import tqdm
import torchvision
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.models import mobilenet_v2

In [None]:
import getpass

home = os.path.expanduser('~')
os.chdir(home) # b/c we will be using universal paths

# Connection settings. Every value can be overridden through an environment
# variable so that credentials never have to be committed with the notebook.
host = os.environ.get('AI_PROJ_DB_HOST', '127.0.0.1')
user = os.environ.get('AI_PROJ_DB_USER', 'root') # change to your username
database = os.environ.get('AI_PROJ_DB_NAME', 'ai_proj_2025') # we should all have this as the db name

# The password is never stored in the notebook: read it from the environment
# or, failing that, prompt for it interactively.
password = os.environ.get('AI_PROJ_DB_PASSWORD')
if password is None:
    password = getpass.getpass(f'MySQL password for {user}@{host}: ')

try:
    conn = connector.connect(
        host=host,
        user=user,
        password=password,
        database=database,
    )
except connector.Error as err:
    # Show the driver's message, then stop: every later cell needs `conn`,
    # so continuing would only fail later with a confusing NameError.
    print(err)
    raise
else:
    print('success')

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device {device}')

In [None]:
# research_dir = Path(home, 'Desktop', 'Education', 'Spring 2025', 'AI', 'research')
research_dir = Path(home, "ai_research_proj_spring_2025", "research_ai_class_spring_2025") # in lab 409 computer for Agafia

# data_helper is imported while the working directory is research_dir.
# The try/finally guarantees we return to `home` even when the import fails,
# so paths written relative to `home` keep resolving correctly afterwards.
os.chdir(research_dir)
try:
    from data_helper import SQLDataset_Informative
finally:
    os.chdir(home)

In [None]:
from torchvision.transforms import v2

# Per-channel statistics used by the final normalisation step.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Augmentation + preprocessing pipeline, applied in the order listed:
# random 224x224 crop, random horizontal flip, conversion to scaled float32,
# and finally channel-wise normalisation.
augmentation_steps = [
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=channel_mean, std=channel_std),
]
transformations = v2.Compose(augmentation_steps)

In [None]:
# create train, val, and test sets

# Relative path: it is resolved against the current working directory.
data_dir=Path('OneDrive - Stephen F. Austin State University', 'CrisisMMD_v2.0','CrisisMMD_v2.0')

# Validation / test images must be preprocessed deterministically. Reusing the
# random training augmentations here would make the reported metrics change
# from one evaluation to the next and add noise to best-model selection.
# Same dtype conversion and normalisation as the training pipeline, but with a
# fixed resize + centre crop instead of the random crop / flip.
eval_transformations = v2.Compose([
    v2.Resize(size=256, antialias=True),
    v2.CenterCrop(size=(224, 224)),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Arguments shared by all three splits.
dataset_kwargs = dict(conn=conn, img_col='image_path', label_col='image_info',
                      data_dir=data_dir, table_name='Hurricane_Images')

train_set = SQLDataset_Informative(transform=transformations, is_train=True, **dataset_kwargs)
val_set = SQLDataset_Informative(transform=eval_transformations, is_val=True, **dataset_kwargs)
test_set = SQLDataset_Informative(transform=eval_transformations, is_test=True, **dataset_kwargs)

In [None]:
# Grab the first two samples of the train and validation splits (quick sanity probe).
train_set_2points = [train_set[idx] for idx in (0, 1)]
val_set_2points = [val_set[idx] for idx in (0, 1)]

In [None]:
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch; without shuffling the model sees
# the samples in the same (dataset-defined) order each time, which can bias
# the gradient updates. Evaluation loaders keep a fixed order.
train_loader = DataLoader(train_set, batch_size=256, shuffle=True)
val_loader = DataLoader(val_set, batch_size=128)
test_loader = DataLoader(test_set, batch_size=128)

# for data in train_loader:
#     print(data['label'])

In [None]:
# importing accuracy metric functions
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

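**IMPORTANT**: Change `run_name` for every new run (and, if the model changes, the `architecture` argument passed to `train_eval`, which is recorded in the wandb run config). The best checkpoint is saved under `run_name`, so reusing a name overwrites the previous run's saved model.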

In [None]:
# validation fn

def dev(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` and log the metrics to wandb.

    Args:
        model: network that is called on each batch of images and returns one
            row of class logits per sample (the prediction is the argmax over
            dim 1).
        val_loader: sized iterable of batches; each batch is indexed with the
            keys ``'image'`` and ``'label'``.

    Returns:
        Overall accuracy over every sample seen, as computed by
        ``accuracy_score``.

    Raises:
        ValueError: if ``val_loader`` yields no samples.

    Side effects:
        Moves ``model`` to the module-level ``device``, puts it in eval mode,
        logs accuracy / macro precision / macro recall / macro F1 / average
        loss through ``wandb.log`` and prints the accuracy.
    """
    avg = 'macro' # used when computing certain accuracy metrics

    # Built once; 'sum' lets us compute an exact per-sample mean afterwards
    # even when the last batch is smaller than the others.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    model.to(device)
    model.eval()

    summed_loss = 0.0
    all_preds = []
    all_trues = []

    with torch.no_grad():
        for batch in tqdm.tqdm(val_loader, total=len(val_loader),
                               desc="Processing validation data"):
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            # Call the module itself (not .forward) so registered hooks run.
            raw_logits = model(images)

            preds = torch.argmax(raw_logits, dim=1) # https://discuss.pytorch.org/t/cross-entropy-loss-get-predicted-class/58215

            summed_loss += criterion(raw_logits, labels).item()

            all_preds.extend(preds.tolist())
            all_trues.extend(labels.tolist())

    if not all_trues:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError('val_loader produced no samples; cannot compute validation metrics')

    # metrics
    acc_total = accuracy_score(y_true=all_trues, y_pred=all_preds)
    precision = precision_score(y_true=all_trues, y_pred=all_preds, zero_division=0, average=avg)
    recall = recall_score(y_true=all_trues, y_pred=all_preds, zero_division=0, average=avg)
    f1 = f1_score(y_true=all_trues, y_pred=all_preds, zero_division=0, average=avg)

    avg_eval_loss = summed_loss / len(all_trues)

    metrics = {
        'accuracy': acc_total,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'avg_eval_loss': avg_eval_loss
    }
    wandb.log(metrics)
    print('****Evaluation****')
    print(f'total_accuracy: {acc_total}')

    return acc_total
    


In [None]:
def train_eval(model, num_epochs, run_name, lr, architecture, frozen_layers, dataset='CrisisMMD'):
    """Train ``model``, validate after every epoch and keep the best checkpoint.

    Uses the module-level ``train_loader``, ``val_loader``, ``device`` and
    ``research_dir``, and the ``dev`` function for validation.

    Args:
        model: network to train; called on image batches, returns class logits.
        num_epochs: number of passes over ``train_loader``.
        run_name: wandb run name, also used as the checkpoint file name under
            ``research_dir/models``.
        lr: initial learning rate for Adam.
        architecture: free-form label stored in the wandb run config.
        frozen_layers: value stored in the wandb run config (not used to
            freeze anything here).
        dataset: dataset label stored in the wandb run config.

    Returns:
        The best validation accuracy reached over all epochs (0.0 if no epoch
        beat that).

    Notes:
        wandb authentication is NOT handled here: set ``WANDB_API_KEY`` in the
        environment or run ``wandb login`` beforehand.
        NOTE(review): an API key used to be hard-coded in this function; it
        should be treated as leaked and rotated.
    """
    best_val_acc = 0.0

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr) # Adam optimizer
    # LR is multiplied by 0.1 at epoch 50 and then every 30 epochs after that.
    lr_sched = MultiStepLR(optimizer=optimizer, milestones=list(range(50, num_epochs, 30)), gamma=.1)

    # Weighted loss: class index 1 counts 1.5x as much as class index 0.
    # CrossEntropyLoss takes raw logits (it applies LogSoftmax internally).
    class_weights = torch.tensor([1.0, 1.5], device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # for saving the models
    models_dir = Path(research_dir, 'models')
    models_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = models_dir / f'{run_name}'

    # before training, set up wandb for tracking purposes
    run = wandb.init(
        # Set the wandb entity where your project will be logged (generally your team name).
        entity="agafiabschool-stephen-f-austin-state-university",
        # Set the wandb project where this run will be logged.
        project="Research Project for CSCI-1465",
        # Track hyperparameters and run metadata.
        config={
            "learning_rate": lr,
            "architecture": architecture,
            "dataset": dataset,
            "epochs": num_epochs,
            'frozen_layers': frozen_layers,
        }, name=run_name
    )

    try:
        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}')
            wandb.log({'epoch': epoch+1})

            # dev() leaves the model in eval mode, so re-enable training mode
            # at the start of every epoch.
            model.train()

            for b, batch in tqdm.tqdm(enumerate(train_loader),
                                total= len(train_loader), desc=f"Processing training data in epoch {epoch+1}"):
                images = batch['image'].to(device)
                labels = batch['label'].to(device)

                # The optimizer owns every model parameter, so one zero_grad suffices.
                optimizer.zero_grad()

                # forward pass (call the module, not .forward, so hooks run)
                raw_logits = model(images)
                loss = criterion(raw_logits, labels)

                # Logged together so both values share the same wandb step.
                wandb.log({'Train Loss': loss.item(), 'LR': lr_sched.get_last_lr()[0]})

                # backprop!
                loss.backward()
                optimizer.step()

                if (b+1) % 20 == 0:
                    print(f'batch: {b+1} ; loss: {loss.item()}')

            # each epoch, run validation
            acc = dev(model=model, val_loader=val_loader)

            if acc > best_val_acc:
                best_val_acc = acc
                torch.save(model, checkpoint_path)

            lr_sched.step()
    finally:
        # Always close the run (even on error / interrupt) so consecutive
        # calls, e.g. from a grid search, never log into a stale run.
        run.finish()

    return best_val_acc

In [None]:
# time to train
num_epochs = 200
os.chdir(home)

# turn off all but the topmost layers: the first `freeze_up_to` entries of
# model.features are frozen
freeze_up_to = 9

# Load the ImageNet-pretrained reference network once, outside the loop.
# Only the `.features` sub-state-dict is taken: the full state_dict prefixes
# every key with 'features.' / 'classifier.', so loading it into
# `model.features` with strict=False matches no key and silently loads nothing.
pretrained = mobilenet_v2(weights='IMAGENET1K_V2')
weights = pretrained.features.state_dict()

# for grid-search
history = []
# for lr in [10**-4, 30**-4, 10**-3, 30**-3, 10**-2, 30**-2, 10**-1, 30**-1, 1]:
for lr in [10**-2, 10**-1, 30**-1, 1]:
    # run_name = f'MobileNetV2 lr={lr}'
    # NOTE(review): the name says "6 layers frozen" while freeze_up_to is 9 - confirm which is intended.
    run_name = f'MobileNetV2-extra fc - class_weights - 6 layers frozen - lr={lr}'
    my_dict = {'lr': lr}

    # instantiate our model
    model = MobileNetV2()

    # load pretrained weights
    # Assumes model.features uses the same parameter names as torchvision's
    # mobilenet_v2().features - TODO confirm against model_helper. Only
    # entries whose name AND shape match are copied, and the count is printed
    # so a failed transfer is visible instead of silent.
    target_state = model.features.state_dict()
    compatible = {
        key: tensor for key, tensor in weights.items()
        if key in target_state and target_state[key].shape == tensor.shape
    }
    model.features.load_state_dict(compatible, strict=False)
    print(f'Loaded {len(compatible)}/{len(target_state)} pretrained tensors into model.features')
    if not compatible:
        print('WARNING: no pretrained weights were transferred; the backbone is randomly initialised')

    for param in model.features[:freeze_up_to].parameters():
        param.requires_grad = False

    # 1280 is the num of features outputted by MobileNetv2 after flattening
    model = nn.Sequential(
        model,
        nn.Linear(1280, 512), nn.ReLU(),
        nn.Linear(512, 128), nn.ReLU(),
        nn.Linear(128, 2),
    )

    acc = train_eval(model, num_epochs, run_name, lr=lr, architecture='MobileNetV2', frozen_layers=freeze_up_to, dataset='CrisisMMD - Hurricane Images')
    my_dict['acc'] = acc
    history.append(my_dict)