In [7]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from ANGP import ANG_P
from torch.utils.data import random_split
from torch.utils.data import ConcatDataset

def train_1(net, data_loader, opt, crit, num_epochs, len_dataset, device, name):
    """Train the first output of ``net`` and checkpoint the best epoch.

    Every epoch makes one pass over ``data_loader``, minimising ``crit``
    between the first value returned by ``net`` and column 0 of the labels.
    Whenever the epoch's average loss beats the best seen so far, the
    model's ``state_dict`` is saved to ``name``.

    Args:
        net: Module whose forward pass returns a pair ``(output1, output2)``;
            only ``output1`` is used here.
        data_loader: Iterable of ``(feature, labels)`` batches, where
            ``labels`` is indexed as ``labels[:, 0]``.
        opt: Optimizer stepped once per batch.
        crit: Loss callable taking ``(prediction, target)``.
        num_epochs: Number of passes over ``data_loader``.
        len_dataset: Sample count used as the denominator of the epoch
            average (batch losses are weighted by batch length).
        device: Device the batches are moved to.
        name: Path the best ``state_dict`` is written to.
    """
    net.train()
    # Start from +inf so the first epoch is always checkpointed; a finite
    # sentinel would silently skip saving when losses stay above it.
    best_loss1 = float("inf")

    for epoch in range(num_epochs):
        total_loss1 = 0.0
        for feature, labels in data_loader:
            feature = feature.to(device)
            labels = labels.to(device)
            target1 = labels[:, 0]

            opt.zero_grad()
            output1, _ = net(feature)
            # Match the target's shape explicitly. A bare squeeze() turns a
            # one-sample batch into a 0-d tensor, which makes the loss fall
            # back to broadcasting.
            output1 = output1.reshape(target1.shape)
            loss1 = crit(output1, target1)
            loss1.backward()
            opt.step()

            # Weight by batch length so a short final batch is not over-counted.
            total_loss1 += loss1.item() * len(labels)

        epoch_loss1 = total_loss1 / len_dataset
        if epoch_loss1 < best_loss1:
            torch.save(net.state_dict(), name)
            best_loss1 = epoch_loss1
        print("epoch{}:mse1:{}".format(epoch + 1, epoch_loss1))

    print("best loss:{}".format(best_loss1))
    
def train_2(net, data_loader, opt, crit, num_epochs, len_dataset, device, name):
    """Train the second output of ``net`` and checkpoint the best epoch.

    Every epoch makes one pass over ``data_loader``, minimising ``crit``
    between the second value returned by ``net`` and column 1 of the labels.
    Whenever the epoch's average loss beats the best seen so far, the
    model's ``state_dict`` is saved to ``name``.

    Args:
        net: Module whose forward pass returns a pair ``(output1, output2)``;
            only ``output2`` is used here.
        data_loader: Iterable of ``(feature, labels)`` batches, where
            ``labels`` is indexed as ``labels[:, 1]``.
        opt: Optimizer stepped once per batch.
        crit: Loss callable taking ``(prediction, target)``.
        num_epochs: Number of passes over ``data_loader``.
        len_dataset: Sample count used as the denominator of the epoch
            average (batch losses are weighted by batch length).
        device: Device the batches are moved to.
        name: Path the best ``state_dict`` is written to.
    """
    # NOTE(review): train() also switches any frozen sub-modules to training
    # mode; if they contain BatchNorm/Dropout layers their behaviour is still
    # affected during this stage -- confirm that is intended.
    net.train()
    # Start from +inf so the first epoch is always checkpointed; a finite
    # sentinel would silently skip saving when losses stay above it.
    best_loss2 = float("inf")

    for epoch in range(num_epochs):
        total_loss2 = 0.0
        for feature, labels in data_loader:
            feature = feature.to(device)
            labels = labels.to(device)
            target2 = labels[:, 1]

            opt.zero_grad()
            _, output2 = net(feature)
            # Match the target's shape explicitly. A bare squeeze() turns a
            # one-sample batch into a 0-d tensor, which makes the loss fall
            # back to broadcasting.
            output2 = output2.reshape(target2.shape)
            loss2 = crit(output2, target2)
            loss2.backward()
            opt.step()

            # Weight by batch length so a short final batch is not over-counted.
            total_loss2 += loss2.item() * len(labels)

        epoch_loss2 = total_loss2 / len_dataset
        if epoch_loss2 < best_loss2:
            torch.save(net.state_dict(), name)
            best_loss2 = epoch_loss2
        print("epoch{}:mse2:{}".format(epoch + 1, epoch_loss2))

    print("best loss:{}".format(best_loss2))


def test(net, data_loader, crit, len_dataset, device):
    """Evaluate both outputs of ``net`` and report regression metrics.

    Runs ``net`` in eval mode over ``data_loader`` without gradients.  Three
    running losses are accumulated with ``crit`` (output 1 vs label column 0,
    output 2 vs label column 1, and the sum of outputs vs the sum of the two
    label columns), each weighted by batch length and finally divided by
    ``len_dataset``.  MSE, RMSE, MAPE and R^2 are computed per sample on the
    summed prediction ``output1 + output2`` against the summed target.

    Args:
        net: Module whose forward pass returns a pair ``(output1, output2)``.
        data_loader: Iterable of ``(feature, labels)`` batches, where
            ``labels`` is indexed as ``labels[:, 0]`` and ``labels[:, 1]``.
        crit: Loss callable taking ``(prediction, target)``.
        len_dataset: Sample count used as the denominator for the three
            reported losses.
        device: Device the batches are moved to.

    Returns:
        dict with float entries ``mse``, ``rmse``, ``mape``, ``r2``,
        ``test_loss1``, ``test_loss2`` and ``test_loss``.  The same values
        are printed.  ``mape`` is non-finite if any summed target is zero;
        ``r2`` is NaN when the targets have zero variance.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    net.eval()
    pred1_chunks, pred2_chunks = [], []
    target1_chunks, target2_chunks = [], []
    total_loss1 = 0.0
    total_loss2 = 0.0
    total_loss = 0.0

    with torch.no_grad():
        for feature, labels in data_loader:
            feature = feature.to(device)
            labels = labels.to(device)
            target1 = labels[:, 0]
            target2 = labels[:, 1]

            output1, output2 = net(feature)
            # Reshape to the target's shape so every batch (including one
            # holding a single sample) yields a 1-D tensor.
            output1 = output1.reshape(target1.shape)
            output2 = output2.reshape(target2.shape)

            batch_len = len(labels)
            total_loss1 += crit(output1, target1).item() * batch_len
            total_loss2 += crit(output2, target2).item() * batch_len
            total_loss += crit(output1 + output2, target1 + target2).item() * batch_len

            pred1_chunks.append(output1.cpu())
            pred2_chunks.append(output2.cpu())
            target1_chunks.append(target1.cpu())
            target2_chunks.append(target2.cpu())

    if not pred1_chunks:
        raise ValueError("data_loader produced no batches to evaluate")

    # Concatenate along the sample axis: batches may differ in length, so
    # they cannot be stacked into a rectangular array.
    y_pred = (torch.cat(pred1_chunks) + torch.cat(pred2_chunks)).double().numpy()
    y_target = (torch.cat(target1_chunks) + torch.cat(target2_chunks)).double().numpy()

    residual = y_pred - y_target
    ss_res = float(np.sum(residual ** 2))
    ss_tot = float(np.sum((y_target - y_target.mean()) ** 2))

    mse = ss_res / len(y_pred)
    rmse = float(np.sqrt(mse))
    mape = float(np.mean(np.abs(residual / y_target)))
    # R^2 compares sums of squares (residual vs total), not a mean to a sum.
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")

    metrics = {
        "mse": mse,
        "rmse": rmse,
        "mape": mape,
        "r2": r2,
        "test_loss1": total_loss1 / len_dataset,
        "test_loss2": total_loss2 / len_dataset,
        "test_loss": total_loss / len_dataset,
    }
    print("mse:{}".format(metrics["mse"]))
    print("rmse:{}".format(metrics["rmse"]))
    print("mape:{}".format(metrics["mape"]))
    print("r2:{}".format(metrics["r2"]))
    print("test_loss1:{}".format(metrics["test_loss1"]))
    print("test_loss2:{}".format(metrics["test_loss2"]))
    print("test_loss:{}".format(metrics["test_loss"]))
    return metrics


# The function name starts with "test"; mark it so pytest does not try to
# collect it as a test case when this module is imported by a test file.
test.__test__ = False


def Power_dataset(feature, label, norm_scale):
    """Load feature/label ``.npy`` files into a ``TensorDataset``.

    Each feature sample is flattened to a 1-D vector, and label columns 1
    and 2 are kept as the two regression targets.  Both arrays are converted
    to ``float32`` tensors.

    Args:
        feature: Path to the feature ``.npy`` file; axis 0 is the sample axis.
        label: Path to the label ``.npy`` file, a 2-D array with at least
            three columns.
        norm_scale: Currently unused -- label normalisation by this factor
            is disabled, so labels are returned in their stored units.

    Returns:
        ``TensorDataset`` of ``(flattened_feature, label[[1, 2]])`` pairs.
    """
    load_feature = np.load(feature)
    load_label = np.load(label)

    # Flatten every axis after the sample axis. Computing the width
    # explicitly (rather than -1) keeps this valid for zero-sample arrays.
    n_samples = load_feature.shape[0]
    flat_width = int(np.prod(load_feature.shape[1:], dtype=np.int64))
    reshape_feature = load_feature.reshape(n_samples, flat_width)

    # Sanity-check print of one raw label row; guarded so files with fewer
    # than three rows do not raise IndexError.
    if len(load_label) > 2:
        print(load_label[2])
    reshape_label = load_label[:, [1, 2]]

    feature_tensor = torch.from_numpy(reshape_feature).float()
    label_tensor = torch.from_numpy(reshape_label).float()
    power_data = TensorDataset(feature_tensor, label_tensor)

    return power_data

# ---------------------------------------------------------------------------
# Two-stage training script.
#   Stage 1: train the full network with train_1 (first output, label col 0).
#   Stage 2: freeze the layers listed below and train the remaining
#            parameters with train_2 (second output, label col 1).
# Both stages checkpoint to the same file, ``model_name``.
# ---------------------------------------------------------------------------
batch_size = 32
epochs = 20000
scale = 1000  # forwarded to Power_dataset as norm_scale
model_name = "power4x4.pt"

device = "cuda:2"
model = ANG_P().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print(next(model.parameters()).device)
dataset = Power_dataset("./data/power_ap_mnist_4x4/feature.npy", "./data/power_ap_mnist_4x4/label.npy", scale)

# 80 / 10 / 10 split; the last part absorbs the rounding remainder so the
# three sizes always add up to len(dataset).
n_total = len(dataset)
n_train = int(0.8 * n_total)
n_val = int(0.1 * n_total)
train_dataset, val_dataset, test_dataset = random_split(dataset, [n_train, n_val, n_total - n_train - n_val])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation loaders do not need shuffling; a fixed order keeps predictions
# aligned with the dataset order.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ---- Stage 1 ----
train_1(model, train_loader, optimizer, criterion, epochs, len(train_dataset), device, model_name)
# Restore the best stage-1 checkpoint. The file holds a plain state_dict, so
# weights_only=True is sufficient and avoids unpickling arbitrary objects;
# map_location places the tensors on the training device.
model.load_state_dict(torch.load(model_name, map_location=device, weights_only=True))

# ---- Stage 2: freeze everything listed here, including the first head ----
layers_to_freeze = [model.conv1,  model.block1, model.block2, model.block3, model.block4, model.block5, model.block6, model.conv2, model.conv3,  model.output_1]

for layer in layers_to_freeze:
    for param in layer.parameters():
        param.requires_grad = False
for name, param in model.named_parameters():
    print(f"{name}: requires_grad = {param.requires_grad}")

# Fresh optimizer over the parameters that are still trainable.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
# Stage 2 reuses the same checkpoint path, so the stage-1 weights on disk are
# replaced once train_2 saves.
train_2(model, train_loader, optimizer, criterion, epochs, len(train_dataset), device, model_name)

model.load_state_dict(torch.load(model_name, map_location=device, weights_only=True))
# test(model, train_loader, criterion, len(train_dataset), device)

cuda:2
[4.583346e+03 4.579100e+03 4.246300e+00]
epoch1:mse1:15950827.4
epoch2:mse1:6985001.95
epoch3:mse1:4916501.075
epoch4:mse1:1129455.6239583334
epoch5:mse1:304976.24296875
epoch6:mse1:106406.31588541667
epoch7:mse1:56937.30963541667
epoch8:mse1:47191.024153645834
epoch9:mse1:37943.019270833334
epoch10:mse1:33062.557421875
epoch11:mse1:27978.28896484375
epoch12:mse1:35509.46129557292
epoch13:mse1:32416.863736979165
epoch14:mse1:44590.02347005208
epoch15:mse1:48085.886979166666
epoch16:mse1:22985.893294270834
epoch17:mse1:23704.954296875
epoch18:mse1:40021.6921875
epoch19:mse1:21050.428450520834
epoch20:mse1:22248.092838541666
epoch21:mse1:21177.42459309896
epoch22:mse1:19211.01845703125
epoch23:mse1:26855.759244791665
epoch24:mse1:20147.991569010417
epoch25:mse1:25621.78253580729
epoch26:mse1:27057.85537109375
epoch27:mse1:19998.700325520833
epoch28:mse1:22175.426009114584
epoch29:mse1:18291.494075520834
epoch30:mse1:19054.640983072917
epoch31:mse1:24385.6927734375
epoch32:mse1:193

  model.load_state_dict(torch.load(model_name))


epoch6:mse2:444.34195404052736
epoch7:mse2:27.385115917523702
epoch8:mse2:0.931146714091301
epoch9:mse2:0.23968069441616535
epoch10:mse2:0.2372934168825547
epoch11:mse2:0.23638996022442976
epoch12:mse2:0.23662822525948285
epoch13:mse2:0.23647548531492552
epoch14:mse2:0.2364136755466461
epoch15:mse2:0.2365998487919569
epoch16:mse2:0.23644321554650863
epoch17:mse2:0.23661085739731788
epoch18:mse2:0.23642011489719153
epoch19:mse2:0.23648373670876027
epoch20:mse2:0.23668239035954078
epoch21:mse2:0.2363412968814373
epoch22:mse2:0.23648520112037658
epoch23:mse2:0.23664091502626736
epoch24:mse2:0.23649720791727305
epoch25:mse2:0.23646894296010335
epoch26:mse2:0.23651687192420165
epoch27:mse2:0.23650621958076953
epoch28:mse2:0.2365184597671032
epoch29:mse2:0.23645951201518375
epoch30:mse2:0.23647435170908768
epoch31:mse2:0.2370578975727161
epoch32:mse2:0.23668503736456234
epoch33:mse2:0.23652072300513585
epoch34:mse2:0.23660738977293175
epoch35:mse2:0.23649401403963566
epoch36:mse2:0.236642635

  model.load_state_dict(torch.load(model_name))


<All keys matched successfully>