### 1. load the dataset

In [1]:
def inverse_normalize(normalized, min_val, max_val):
    return ((normalized + 1) / 2) * (max_val - min_val) + min_val

In [2]:
from torch.utils import data
from JGC_MMN_dataloader import JGC_MMN_dataloader

In [3]:
train_dataset = JGC_MMN_dataloader(
        name=["long_term.npy", "short_term.npy", "ingredients.npy", "future.npy", "label.npy"],
        root="/Users/gunneo/Documents/4_2/Graduation_Thesis/Datasets/Beijing_House_Price_Dataset/JGC_MMN/train",
    )
train_dataloader = data.DataLoader(train_dataset, batch_size=36, shuffle=True)
img_h = train_dataloader.dataset[0][0].shape[1]
img_w = train_dataloader.dataset[0][0].shape[2]
long_term_in_channels = train_dataloader.dataset[0][0].shape[0]
short_term_in_channels = train_dataloader.dataset[0][1].shape[0]
cur_ingred_dim = train_dataloader.dataset[0][2].shape[0]

print(img_h, img_w, long_term_in_channels, short_term_in_channels, cur_ingred_dim)

30 30 2 12 10


### 2. initialize the model

In [4]:
from JGC_MMN import JGC_MMN

In [5]:
# train_model = JGC_MMN(
#     branches = [0, 1, 2],
#     # img config
#     img_h = img_h,
#     img_w = img_w,
#     # short term config
#     short_term_in_channels = short_term_in_channels,
#     short_term_growth_rate = 4,
#     short_term_block_config = (9, 9, 9, 9, 9, 9),
#     # long term config
#     long_term_in_channels = long_term_in_channels,
#     long_term_growth_rate = 4,
#     long_term_block_config = (9, 9, 9, 9, 9, 9),
#     # cur ingredient config
#     cur_ingred_dim = cur_ingred_dim,
#     emedding_dim = 64,
#     # future price growth expectation config
#     # fusion config
#     modalities = 3,
#     indexs = [0, 1],
# )

### 3. initialize train process

In [6]:
import os
import torch
import torch.nn as nn
from torch.optim import Adam, SGD, RMSprop, Adagrad

def rmse_loss(pred, target):
    mse = nn.MSELoss()
    return torch.sqrt(mse(pred, target))

def train(
        model,
        train_dataloader,
        w_path,
        num_epochs=50,
        optimizer_type='adam',
        lr=5e-4,
        input_indices=[0, 1, 2, 3],
        patience=5,
        save_interval=25,
        loss_threshold=1e-4
        ):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
        
    optimizers = {
        'adam': Adam(model.parameters(), lr=lr),
        'sgd': SGD(model.parameters(), lr=lr, momentum=0.9),
        'rmsprop': RMSprop(model.parameters(), lr=lr),
        'adagrad': Adagrad(model.parameters(), lr=lr)
    }
    if optimizer_type not in optimizers:
        raise ValueError("Unsupported optimizer type")
    optimizer = optimizers[optimizer_type]

    best_model_path = os.path.join(w_path, "best_model.pth")
    best_loss = float('inf')
    no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        epoch_train_loss = 0.0
        
        for batch in train_dataloader:
            inputs = [batch[i].to(device) for i in input_indices]
            labels = batch[-1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = rmse_loss(outputs, labels)
            loss.backward()
            optimizer.step()
            
            epoch_train_loss += loss.item() * labels.size(0)
        
        epoch_train_loss /= len(train_dataloader.dataset)
        print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {epoch_train_loss:.4f}")

        if abs(epoch_train_loss - best_loss) >= loss_threshold:
            best_loss = epoch_train_loss
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved to {best_model_path} with loss {best_loss:.4f}")
            no_improve = 0
        else:
            no_improve += 1
            print(f"No improvement for {no_improve}/{patience} epochs")
            if no_improve > patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            checkpoint_path = os.path.join(w_path, f"checkpoint_{epoch+1}.pth")
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")

    print("Training complete.")

In [7]:
# train(
#     train_model,
#     train_dataloader,
#     w_path="/Users/gunneo/Documents/4_2/Graduation_Thesis/Datasets/Beijing_House_Price_Dataset/JGC_MMN/train_weights/",
#     num_epochs=500,
#     optimizer_type='adam',
#     lr=5e-3,
#     input_indices=[0, 1, 2],
#     patience=500,
#     save_interval=25,
#     loss_threshold=5e-4
# )

In [8]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# criterion = nn.MSELoss()
# tanh = nn.Tanh()
# num_batches = 0
# rmse_total = 0.0
# with torch.no_grad():
#     for batch in train_dataloader:
#         inputs = [batch[i].to(device) for i in [0, 1, 2]]
#         labels = batch[-1].to(device)
        
#         predictions = train_model(inputs)
#         # convert the normalized prediction back to the original scale
#         print(predictions[0][0][14])
#         print(labels[0][0][14])
#         mse = criterion(predictions, tanh(labels))
#         rmse = torch.sqrt(mse)
#         rmse_total += rmse.item()
#         num_batches += 1

# average_rmse = rmse_total / num_batches
# print(f"Average RMSE: {average_rmse}")

### 4. test the result

1. global min-max normalization, indexs = [0, 1, 2], trained for 290 epochs, lr = 1e-3, "adam" optimizer, patience = 5, final loss for training set is 0.1860, test RMSE[-1, 1]: 0.3251001834869385

2. global min-max normalization, indexs = [0, 1, 2], trained for 264 epochs, lr = 1e-3, "adam" optimizer, patience = 10, final loss for training set is 0.1346, test RMSE[-1, 1]: 0.269204705953598

3. global min-max normalization, indexs = [0, 1, 2], trained for ? epochs, lr = 5e-3, "adam" optimizer, patience = 5, final loss for training set is 0.0911, test RMSE[-1, 1]: 0.21775959432125092

4. global min-max normalization, indexs = [0, 1, 2], trained for 219 epochs, lr = 5e-3, "adam" optimizer, patience = 5, final loss for training set is 0.0892, test RMSE[-1, 1]: 0.21174463629722595

5. global min-max normalization, indexs = [0, 1, 2], trained for 500(full training) epochs, lr = 5e-3, "adam" optimizer, patience = 500, final loss for training set is 0.0775, test RMSE[-1, 1]: 0.22640728950500488

In [9]:
test_dataset = JGC_MMN_dataloader(
        name=["long_term.npy", "short_term.npy", "ingredients.npy", "future.npy", "label.npy"],
        root="/Users/gunneo/Documents/4_2/Graduation_Thesis/Datasets/Beijing_House_Price_Dataset/JGC_MMN/test",
    )
test_dataloader = data.DataLoader(test_dataset, batch_size=12, shuffle=True)
test_global_min = test_dataset.global_min
test_global_max = test_dataset.global_max

In [10]:
test_model = JGC_MMN(
    branches = [0, 1, 2],
    # img config
    img_h = img_h,
    img_w = img_w,
    # short term config
    short_term_in_channels = short_term_in_channels,
    short_term_growth_rate = 4,
    short_term_block_config = (9, 9, 9, 9, 9, 9),
    # long term config
    long_term_in_channels = long_term_in_channels,
    long_term_growth_rate = 4,
    long_term_block_config = (9, 9, 9, 9, 9, 9),
    # cur ingredient config
    cur_ingred_dim = cur_ingred_dim,
    emedding_dim = 64,
    # future price growth expectation config
    # fusion config
    modalities = 3,
    indexs = [0, 1],
)

In [11]:
test_model.load_state_dict(torch.load("/Users/gunneo/Documents/4_2/Graduation_Thesis/Datasets/Beijing_House_Price_Dataset/JGC_MMN//weights_bak/4/best_model.pth"))

  test_model.load_state_dict(torch.load("/Users/gunneo/Documents/4_2/Graduation_Thesis/Datasets/Beijing_House_Price_Dataset/JGC_MMN//weights_bak/4/best_model.pth"))


<All keys matched successfully>

In [12]:
def test(
    test_model,
    test_dataloader,
    input_indexs = [0, 1, 2, 3]
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_model = test_model.to(device)
    num_batches = 0
    rmse_total = 0.0
    conversion_rmse_total = 0.0
    with torch.no_grad():
        for batch in test_dataloader:
            inputs = [batch[i].to(device) for i in input_indexs]
            labels = batch[-1].to(device)
            
            predictions = test_model(inputs)
            print("without conversion: ")
            print(predictions[0][0][14])
            print(labels[0][0][14])
            rmse = rmse_loss(predictions, labels)
            rmse_total += rmse.item()
            print('------------------------------------')
            print("with conversion: ")
            predictions = inverse_normalize(predictions, test_global_min[-1], test_global_max[-1])
            labels = inverse_normalize(labels, test_global_min[-1], test_global_max[-1])
            print(predictions[0][0][14])
            print(labels[0][0][14])
            conversion_rmse = rmse_loss(predictions, labels)
            conversion_rmse_total += conversion_rmse.item()
            num_batches += 1

    average_rmse = rmse_total / num_batches
    print(f"Average RMSE: {average_rmse}")
    average_conversion_rmse = conversion_rmse_total / num_batches
    print(f"Average Conversion RMSE: {average_conversion_rmse}")

In [13]:
test(test_model, test_dataloader, [0, 1, 2])

without conversion: 
tensor([-0.6639, -0.6287, -0.9998, -0.6655, -0.8170, -0.4762, -0.8754, -0.7860,
         0.0798, -0.0462,  0.2001,  0.1219,  0.1374, -0.1622, -0.4785, -0.1101,
         0.1910, -0.1151, -0.0517, -0.2776,  0.1169, -0.5404, -0.6748, -0.3030,
        -0.8263, -0.9657, -0.9259, -0.7056, -0.5707, -0.9909])
tensor([-1.0000, -0.4420, -1.0000, -0.4529, -1.0000, -0.1551, -1.0000, -1.0000,
         0.0600, -0.1786,  0.3060, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         0.1627, -0.1500, -0.3536, -0.2716, -0.0039, -1.0000, -1.0000, -0.4350,
        -0.5826, -1.0000, -0.5114, -0.5083, -1.0000, -1.0000])
------------------------------------
with conversion: 
tensor([1.6507e+02, 1.8237e+02, 1.0452e-01, 1.6429e+02, 8.9871e+01, 2.5728e+02,
        6.1211e+01, 1.0511e+02, 5.3036e+02, 4.6850e+02, 5.8946e+02, 5.5109e+02,
        5.5870e+02, 4.1151e+02, 2.5614e+02, 4.3713e+02, 5.8501e+02, 4.3463e+02,
        4.6581e+02, 3.5486e+02, 5.4862e+02, 2.2574e+02, 1.5974e+02, 3.4236e+02