In [1]:
# temporary solution to crashing
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

In [2]:
import torch
import torch.nn as nn
from model_helper import MobileNetV2
import wandb
import mysql.connector as connector
from pathlib import Path
import tqdm
import torchvision
import torch.nn.functional as F

In [3]:
home = os.path.expanduser('~')
os.chdir(home) # b/c we will be using universal paths

host = '127.0.0.1'
user = 'root' # change to your username
password = 'vasya1' # change to your password
database = 'ai_proj_2025' # we should all have this as the db name 

try:
    conn = connector.connect(
        host = host, 
        user = user, 
        password = password, 
        database = database
    )
    print('success')
except connector.Error as err:
    print(err)

success


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device {device}')

Using device cuda


In [5]:
# research_dir = Path(home, 'Desktop', 'Education', 'Spring 2025', 'AI', 'research')
research_dir = Path(home, "ai_research_proj_spring_2025", "research_ai_class_spring_2025") # in lab 409 computer for Agafia
os.chdir(research_dir)

from data_helper import SQLDataset_Informative

os.chdir(home)

In [6]:
from torchvision.transforms import v2

# transforms
transformations = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True), 
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [7]:
# create train, val, and test sets

data_dir=Path('OneDrive - Stephen F. Austin State University', 'CrisisMMD_v2.0','CrisisMMD_v2.0')

train_set = SQLDataset_Informative(conn=conn, img_col='image_path', label_col='image_info', transform=transformations, 
                     data_dir=data_dir, is_train=True)
val_set = SQLDataset_Informative(conn=conn, img_col='image_path', label_col='image_info', transform=transformations, 
                     data_dir=data_dir, is_val=True)
test_set = SQLDataset_Informative(conn=conn, img_col='image_path', label_col='image_info', transform=transformations, 
                     data_dir=data_dir, is_test=True)

In [8]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set, batch_size=128)
val_loader = DataLoader(val_set, batch_size=128)
test_loader = DataLoader(test_set, batch_size=128)

In [9]:
# importing accuracy metric functions
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

**IMPORTANT**: Make sure to change the run_name (and 'architecture' parameter of the wandb `run` variable if necessary) with each new run. 

In [10]:
# validation fn

def dev(model, val_loader):
    model.to(device)
    batch_size = val_loader.batch_size
    avg = 'macro' # used when computing certain accuracy metrics
    model.eval()

    eval_loss = 0

    all_preds = []
    all_trues = []

    with torch.no_grad():
        for b, batch in tqdm.tqdm(enumerate(val_loader), 
                             total= len(val_loader), desc=f"Processing validation data"):
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            raw_logits = model.forward(images)

            preds = torch.argmax(raw_logits, dim=1) # https://discuss.pytorch.org/t/cross-entropy-loss-get-predicted-class/58215

            loss = nn.CrossEntropyLoss()(raw_logits, labels)

            eval_loss += loss.item()

            all_preds.extend(preds.tolist())
            all_trues.extend(labels.tolist())


        # metrics 
        acc_total = accuracy_score(y_true=all_trues, y_pred=all_preds)
        precision = precision_score(y_true=all_trues, y_pred=all_preds, zero_division=0, average=avg)
        recall = recall_score(y_true=all_trues, y_pred=all_preds, zero_division=0, average=avg)
        f1 = f1_score(y_true=all_trues, y_pred=all_preds, zero_division=0, average=avg)

        avg_eval_loss = eval_loss / (len(val_loader))

        metrics = {
            'accuracy': acc_total, 
            'precision': precision, 
            'recall': recall, 
            'f1': f1, 
            'avg_eval_loss': avg_eval_loss
        }
        wandb.log(metrics)
        print('****Evaluation****')
        print(f'total_accuracy: {acc_total}')

        return acc_total


In [None]:
def train_eval(model, num_epochs, run_name, lr, architecture):
    # training hyperparameters & functions/tools
    lr = lr 
    num_epochs = num_epochs
    run_name = run_name
    

    best_val_acc = 0.0
    optimizer = torch.optim.Adam(model.parameters(), lr=lr) #stocastic gradient descent for our optimization algorithm

    model.to(device)
    # for saving the models
    Path(research_dir, 'models' ).mkdir(parents=True, exist_ok=True)

    # before training, set up wandb for tracking purposes
    os.environ["WANDB_API_KEY"] = "5a08d1ebbf0e86ab877a128b98be3c320301b6a0"

    run = wandb.init(
        # Set the wandb entity where your project will be logged (generally your team name).
        entity="agafiabschool-stephen-f-austin-state-university",
        # Set the wandb project where this run will be logged.
        project="Research Project for CSCI-1465",
        # Track hyperparameters and run metadata.
        config={
            "learning_rate": lr,
            "architecture": architecture,
            "dataset": "CrisisMMD",
            "epochs": num_epochs,
        }, name=run_name
    )


    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}')
        wandb.log({'epoch': epoch+1})

        loss = 0

        for b, batch in tqdm.tqdm(enumerate(train_loader), 
                            total= len(train_loader), desc=f"Processing training data in epoch {epoch+1}"):
            model.train()
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            model.zero_grad() 
            optimizer.zero_grad()

            # forward pass
            raw_logits = model.forward(images)
            # loss - can use raw logits for this function b/c it applies LogSoftmax 
            loss = nn.CrossEntropyLoss()(raw_logits, labels)

            wandb.log({'Train Loss': loss.item()})

            # backprop!
            loss.backward()

            optimizer.step()

            if (b+1) % 10 == 0:
                print(f'batch: {b+1} ; loss: {loss.item()}')
        
        # each epoch, run validation
        acc = dev(model=model, val_loader=val_loader)
        
        if acc > best_val_acc:
            best_val_acc = acc
            torch.save(model, Path(research_dir, 'models', f'{run_name}'))
    
    return best_val_acc

In [None]:
# time to train
num_epochs = 1000
os.chdir(home)

# for grid-search
history = []
for lr in [10**-4, 30**-4, 10**-3, 30**-3, 10**-2, 30**-2, 10**-1, 30**-1, 1]:
    run_name = f'MobileNetV2 Test lr={lr}'
    my_dict = {}
    my_dict['lr'] = lr
    model = MobileNetV2()
    model = nn.Sequential(model, nn.Linear(1280, 2)) # 1280 is the num of features outputted by MobileNetv2 after flattening
    my_dict['acc'] = train_eval(model, 1000, run_name, lr, 'MobileNetV2')
    history.append(my_dict)

wandb: Currently logged in as: agafiabschool (agafiabschool-stephen-f-austin-state-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch 1


Processing training data in epoch 1:   7%|▋         | 9/128 [00:04<00:54,  2.17it/s]

batch: 10 ; loss: 0.7034508585929871


Processing training data in epoch 1:  15%|█▍        | 19/128 [00:09<00:51,  2.11it/s]

batch: 20 ; loss: 0.6978794932365417


Processing training data in epoch 1:  23%|██▎       | 29/128 [00:14<00:47,  2.07it/s]

batch: 30 ; loss: 0.7482956647872925


Processing training data in epoch 1:  30%|███       | 39/128 [00:19<00:43,  2.03it/s]

batch: 40 ; loss: 0.6774151921272278


Processing training data in epoch 1:  38%|███▊      | 49/128 [00:24<00:39,  2.02it/s]

batch: 50 ; loss: 0.6924484372138977


Processing training data in epoch 1:  46%|████▌     | 59/128 [00:29<00:34,  1.99it/s]

batch: 60 ; loss: 0.7026772499084473


Processing training data in epoch 1:  54%|█████▍    | 69/128 [00:34<00:29,  1.98it/s]

batch: 70 ; loss: 0.6997445225715637


Processing training data in epoch 1:  62%|██████▏   | 79/128 [00:39<00:23,  2.06it/s]

batch: 80 ; loss: 0.6904569864273071


Processing training data in epoch 1:  70%|██████▉   | 89/128 [00:44<00:19,  2.04it/s]

batch: 90 ; loss: 0.8762502670288086


Processing training data in epoch 1:  77%|███████▋  | 99/128 [00:49<00:14,  2.05it/s]

batch: 100 ; loss: 0.7036304473876953


Processing training data in epoch 1:  85%|████████▌ | 109/128 [00:54<00:08,  2.14it/s]

batch: 110 ; loss: 0.7846125960350037


Processing training data in epoch 1:  93%|█████████▎| 119/128 [00:59<00:04,  2.18it/s]

batch: 120 ; loss: 0.724918782711029


Processing training data in epoch 1: 100%|██████████| 128/128 [01:03<00:00,  2.03it/s]
Processing validation data: 100%|██████████| 8/8 [00:03<00:00,  2.18it/s]


****Evaluation****
total_accuracy: 0.5298672566371682
Epoch 2


Processing training data in epoch 2:   7%|▋         | 9/128 [00:04<00:53,  2.23it/s]

batch: 10 ; loss: 0.7072049975395203


Processing training data in epoch 2:  15%|█▍        | 19/128 [00:08<00:49,  2.19it/s]

batch: 20 ; loss: 0.6972578763961792


Processing training data in epoch 2:  23%|██▎       | 29/128 [00:13<00:46,  2.12it/s]

batch: 30 ; loss: 0.7856351137161255


Processing training data in epoch 2:  30%|███       | 39/128 [00:18<00:42,  2.11it/s]

batch: 40 ; loss: 0.7220056056976318


Processing training data in epoch 2:  38%|███▊      | 49/128 [00:23<00:37,  2.12it/s]

batch: 50 ; loss: 0.6894674897193909


Processing training data in epoch 2:  46%|████▌     | 59/128 [00:28<00:33,  2.07it/s]

batch: 60 ; loss: 0.7283859252929688


Processing training data in epoch 2:  54%|█████▍    | 69/128 [00:32<00:28,  2.09it/s]

batch: 70 ; loss: 0.6851937174797058


Processing training data in epoch 2:  62%|██████▏   | 79/128 [00:37<00:22,  2.13it/s]

batch: 80 ; loss: 0.7203905582427979


Processing training data in epoch 2:  70%|██████▉   | 89/128 [00:42<00:18,  2.12it/s]

batch: 90 ; loss: 0.805216372013092


Processing training data in epoch 2:  77%|███████▋  | 99/128 [00:47<00:13,  2.15it/s]

batch: 100 ; loss: 0.7607424855232239


Processing training data in epoch 2:  85%|████████▌ | 109/128 [00:51<00:08,  2.27it/s]

batch: 110 ; loss: 0.7887974381446838


Processing training data in epoch 2:  93%|█████████▎| 119/128 [00:56<00:03,  2.26it/s]

batch: 120 ; loss: 0.6984427571296692


Processing training data in epoch 2: 100%|██████████| 128/128 [01:00<00:00,  2.12it/s]
Processing validation data: 100%|██████████| 8/8 [00:03<00:00,  2.44it/s]


****Evaluation****
total_accuracy: 0.5110619469026548
Epoch 3


Processing training data in epoch 3:   7%|▋         | 9/128 [00:03<00:51,  2.29it/s]

batch: 10 ; loss: 0.7126647233963013


Processing training data in epoch 3:  15%|█▍        | 19/128 [00:08<00:48,  2.23it/s]

batch: 20 ; loss: 0.6838854551315308


Processing training data in epoch 3:  23%|██▎       | 29/128 [00:13<00:46,  2.14it/s]

batch: 30 ; loss: 0.7848340272903442


Processing training data in epoch 3:  30%|███       | 39/128 [00:18<00:42,  2.11it/s]

batch: 40 ; loss: 0.749981164932251


Processing training data in epoch 3:  38%|███▊      | 49/128 [00:22<00:37,  2.11it/s]

batch: 50 ; loss: 0.685023844242096


Processing training data in epoch 3:  46%|████▌     | 59/128 [00:27<00:33,  2.08it/s]

batch: 60 ; loss: 0.6947875618934631


Processing training data in epoch 3:  54%|█████▍    | 69/128 [00:32<00:28,  2.10it/s]

batch: 70 ; loss: 0.6845302581787109


Processing training data in epoch 3:  62%|██████▏   | 79/128 [00:37<00:22,  2.13it/s]

batch: 80 ; loss: 0.6901729106903076


Processing training data in epoch 3:  70%|██████▉   | 89/128 [00:42<00:18,  2.13it/s]

batch: 90 ; loss: 0.8098732233047485


Processing training data in epoch 3:  77%|███████▋  | 99/128 [00:47<00:13,  2.15it/s]

batch: 100 ; loss: 0.7150601744651794


Processing training data in epoch 3:  85%|████████▌ | 109/128 [00:51<00:08,  2.26it/s]

batch: 110 ; loss: 0.7601787447929382


Processing training data in epoch 3:  93%|█████████▎| 119/128 [00:56<00:04,  2.12it/s]

batch: 120 ; loss: 0.7162261009216309


Processing training data in epoch 3: 100%|██████████| 128/128 [01:00<00:00,  2.12it/s]
Processing validation data: 100%|██████████| 8/8 [00:03<00:00,  2.45it/s]


****Evaluation****
total_accuracy: 0.5320796460176991
Epoch 4


Processing training data in epoch 4:   5%|▍         | 6/128 [00:03<01:03,  1.92it/s]


KeyboardInterrupt: 