# Experiment Setup

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler

import os, sys
import time

import warnings
import numpy as np
from matplotlib import pyplot as plt

# Make the PatchTST_supervised sources (utils, data_provider, models imported
# below) importable, and switch the working directory to that checkout.
# NOTE: both paths are placeholders — point them at the local PatchTST clone.
sys.path.append('path/to/PatchTST/PatchTST_supervised')
os.chdir('path/to/PatchTST/PatchTST_supervised')

from utils.tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop
from utils.metrics import metric
from data_provider.data_factory import data_provider
from models import Informer, Autoformer, Transformer, DLinear, Linear, NLinear, PatchTST
from models.PatchTST import Model

In [None]:
class Args:
    """Plain attribute container for the model and data-loader settings.

    Every setting is an ordinary instance attribute, so later cells can
    override single values (``args.seq_len = 336``) or copy a whole
    ``dataset_configs`` entry onto the instance with ``setattr``.
    """

    def __init__(self):
        # Model hyper-parameters.
        model_settings = dict(
            enc_in=1,
            seq_len=96,
            label_len=0,
            pred_len=96,
            e_layers=6,
            n_heads=16,
            d_model=128,
            d_ff=256,
            dropout=0.2,
            fc_dropout=0.2,
            head_dropout=0.0,
            individual=False,
            patch_len=16,
            stride=8,
            padding_patch='end',
            revin=True,
            affine=False,
            subtract_last=False,
            decomposition=False,
            kernel_size=25,
        )
        # Data / loader settings.
        data_settings = dict(
            batch_size=128,
            data='ETTh1',
            embed='timeF',
            freq='h',
            root_path='path/to/PatchTST/PatchTST_supervised/dataset',
            data_path='ETTh1.csv',
            features='M',
            target='OT',
            num_workers=10,
        )
        for name, value in (*model_settings.items(), *data_settings.items()):
            setattr(self, name, value)

In [None]:
# Per-dataset settings; the selected entry's keys are set as attributes on `args`.
# NOTE(review): each data_path is a full path (dataset root + file), not a path
# relative to args.root_path — confirm how data_provider combines the two.
# 'Dataset_Patch_dependent' reuses the traffic CSV, with enc_in=1.
dataset_configs = dict(
    ETTm2=dict(
        data='ETTm2',
        data_path='path/to/PatchTST/PatchTST_supervised/dataset/' + '/ETT-small/ETTm2.csv',
        enc_in=7,
    ),
    Electricity=dict(
        data='custom',
        data_path='path/to/PatchTST/PatchTST_supervised/dataset/' + '/electricity/electricity.csv',
        enc_in=321,
    ),
    weather=dict(
        data='custom',
        data_path='path/to/PatchTST/PatchTST_supervised/dataset/' + '/weather/weather.csv',
        enc_in=21,
    ),
    traffic=dict(
        data='custom',
        data_path='path/to/PatchTST/PatchTST_supervised/dataset/' + '/traffic/traffic.csv',
        enc_in=862,
    ),
    Dataset_Patch_dependent=dict(
        data='Dataset_Patch_dependent',
        data_path='path/to/PatchTST/PatchTST_supervised/dataset/' + '/traffic/traffic.csv',
        enc_in=1,
    ),
)

In [None]:
args = Args()

# Per-run overrides of the Args defaults.
args.seq_len, args.stride, args.e_layers = 336, 16, 3
dataset = 'ETTm2'                      # key into dataset_configs
model_path = 'path/to/checkpoint.pth'  # placeholder checkpoint location

# Copy the selected dataset's entries (data, data_path, enc_in) onto args.
for attr, value in dataset_configs[dataset].items():
    setattr(args, attr, value)

train_data, train_loader = data_provider(args, flag='train')
vali_data, vali_loader = data_provider(args, flag='val')
test_data, test_loader = data_provider(args, flag='test')

# The state dict is loaded through an nn.DataParallel wrapper, then the bare
# model is unwrapped for the experiments below.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model
# ('module.'-prefixed keys) — confirm against how it was written.
model = nn.DataParallel(Model(args).to('cuda'))
model.load_state_dict(torch.load(model_path))
model = model.module
model.eval()

In [None]:
from tqdm import tqdm
device = 'cuda'
criterion = nn.MSELoss().to(device)

# Loader to score; set this to train_loader to measure the fit on the training split.
eval_loader = test_loader

# Per-batch MSE values and the number of elements each one averages over.
# Weighting by element count gives the exact MSE over the whole split even when
# the final batch is smaller; an unweighted mean of batch means would over-weight it.
batch_losses = []
batch_counts = []
with torch.no_grad():
    for batch_x, batch_y, batch_x_mark, batch_y_mark in tqdm(eval_loader):
        # Only batch_x is fed to the model; the mark tensors are not used here.
        outputs = model(batch_x.float().to(device))
        # Score the last pred_len steps of every channel, on the CPU
        # (batch_y never needs to visit the GPU).
        pred = outputs[:, -args.pred_len:, 0:].detach().cpu()
        true = batch_y.float()[:, -args.pred_len:, 0:]

        batch_losses.append(criterion(pred, true).item())
        batch_counts.append(pred.numel())
total_loss = np.average(batch_losses, weights=batch_counts)
print('loss:', total_loss)

# Attention Perturbation Experiment

In [None]:
def delete_hooks():
    """Detach every hook handle stored in the module-level ``hooks`` list.

    Calling this before ``hooks`` has ever been assigned (first run of the
    notebook) does nothing. Only that expected ``NameError`` is ignored; an
    error raised by ``hook.remove()`` is no longer silently swallowed by a
    bare ``except``. The list is emptied afterwards so removed handles are
    not kept around.
    """
    global hooks
    try:
        registered = hooks
    except NameError:
        return
    for hook in registered:
        hook.remove()
    registered.clear()

def set_atten_hook(module, input, output):
    """Forward hook that blends ``output`` toward its mean over the last dim.

    The module-level ``attenuation`` (expected in [0, 1]) is the weight kept
    on the original tensor: 1.0 returns ``output`` unchanged, 0.0 returns the
    mean over the last dimension broadcast to the full shape, and values in
    between interpolate linearly. It is registered on each encoder layer's
    ``attn_dropout`` module.

    Alternative ablations tried during development (not active):
      * ZERO: ``torch.zeros_like(output)``
      * MEAN over ``dim=-2`` instead of ``dim=-1``
      * EYE: ``torch.eye(output.shape[2])`` repeated over the first two dims
    """
    keep = attenuation
    last_dim_mean = torch.mean(output, dim=-1, keepdim=True)
    return output * keep + last_dim_mean * (1 - keep)

def noise_score_hook(module, input):
    """Forward pre-hook that adds scaled Gaussian noise to the first input.

    The noise is standard normal, multiplied by the standard deviation of
    ``input[0]`` along its last dimension (one value per slice) and by the
    module-level ``noise_score``; 0.0 leaves the input unchanged. Any further
    positional inputs are passed through untouched.

    Note: ``torch.std`` defaults to the unbiased estimator, so a last
    dimension of size 1 yields NaN noise.
    """
    x = input[0]
    noise = torch.randn_like(x) * torch.std(x, dim=-1, keepdim=True) * noise_score
    # Keep extra positional inputs; returning a 1-tuple would silently drop them.
    return (x + noise, *input[1:])

In [None]:
# Encoder blocks to perturb. Indices past the model depth are simply never
# matched (args.e_layers is 3 for the loaded checkpoint).
zero_atten_blocks = [0, 1, 2, 3, 4, 5]

delete_hooks()
hooks = []
for block_idx, layer in enumerate(model.model.backbone.encoder.layers):
    if block_idx not in zero_atten_blocks:
        continue
    sdp_attn = layer.self_attn.sdp_attn
    # set_atten_hook rewrites the output of attn_dropout.
    hooks.append(sdp_attn.attn_dropout.register_forward_hook(set_atten_hook))
    # noise_score_hook perturbs the input of softmax_layer.
    hooks.append(sdp_attn.softmax_layer.register_forward_pre_hook(noise_score_hook))

# Both hooks read the module-level `attenuation` / `noise_score` globals,
# which the sweep below rebinds for every grid point.
attenuations = [i/10 for i in range(0, 11, 1)] # 0.0 ~ 1.0
noise_scores = [i/10*2 for i in range(0, 11, 1)] # 0.0 ~ 2.0

# losses[i, j] = test MSE at attenuations[i], noise_scores[j]
losses = np.zeros((len(attenuations), len(noise_scores)))

try:
    with torch.no_grad():
        for i_attenuation, attenuation in enumerate(attenuations):
            for j_noise_score, noise_score in enumerate(noise_scores):
                # Element-count weights give the exact MSE over the split even
                # if the final batch is smaller than the others.
                batch_losses, batch_counts = [], []
                for batch_x, batch_y, _x_mark, _y_mark in test_loader:
                    outputs = model(batch_x.float().to(device))
                    pred = outputs[:, -args.pred_len:, 0:].detach().cpu()
                    true = batch_y.float()[:, -args.pred_len:, 0:]
                    batch_losses.append(criterion(pred, true).item())
                    batch_counts.append(pred.numel())
                total_loss = np.average(batch_losses, weights=batch_counts)
                print('attenuation:', attenuation, 'noise_score:', noise_score, 'loss:', total_loss)
                losses[i_attenuation, j_noise_score] = total_loss
finally:
    # Detach the hooks so `model` is unperturbed after the sweep, including when
    # it is interrupted. Without this, the last grid point (noise_score 2.0)
    # would keep injecting noise into every later forward pass.
    for hook in hooks:
        hook.remove()
    hooks.clear()

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

# Loss surface over the (noise_score, attenuation) grid.
# `losses` is indexed [attenuation index, noise_score index], matching the
# (rows, cols) layout returned by np.meshgrid(noise_scores, attenuations).
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
grid_noise, grid_atten = np.meshgrid(noise_scores, attenuations)
ax.plot_surface(grid_noise, grid_atten, losses, rstride=1, cstride=1, cmap='rainbow')
ax.set_xlabel('noise_score')
ax.set_ylabel('attenuation')
ax.set_zlabel('loss')
# NOTE: fixed z-range; losses outside [0.16, 0.18] may be drawn past the axes box.
ax.set_zlim(0.16, 0.18)
ax.invert_yaxis()
ax.invert_xaxis()

# Annotate the four corners of the grid with their loss.
last_atten = len(attenuations) - 1
last_noise = len(noise_scores) - 1
for row in (0, last_atten):
    for col in (0, last_noise):
        corner_loss = losses[row, col]
        ax.text(noise_scores[col], attenuations[row], corner_loss, f'{corner_loss:.3f}', color='black')

plt.show()

# FFN Perturbation Experiment

In [None]:
def delete_hooks():
    """Detach every hook handle stored in the module-level ``hooks`` list.

    Calling this before ``hooks`` has ever been assigned does nothing. Only
    that expected ``NameError`` is ignored; an error raised by
    ``hook.remove()`` is no longer silently swallowed by a bare ``except``.
    The list is emptied afterwards so removed handles are not kept around.
    """
    global hooks
    try:
        registered = hooks
    except NameError:
        return
    for hook in registered:
        hook.remove()
    registered.clear()

def set_ffn_hook(module, input, output):
    """Forward hook that blends ``output`` toward its mean over the last dim.

    It is registered on each encoder layer's ``ff`` module. The module-level
    ``attenuation`` (expected in [0, 1]) is the weight kept on the original
    tensor: 1.0 returns ``output`` unchanged, 0.0 returns the last-dimension
    mean broadcast to the full shape, and values in between interpolate
    linearly.
    """
    weight = attenuation
    feature_mean = torch.mean(output, dim=-1, keepdim=True)
    return output * weight + feature_mean * (1 - weight)

def noise_score_hook(module, input):
    """Forward pre-hook that adds scaled Gaussian noise to the first input.

    The noise is standard normal, multiplied by the standard deviation of
    ``input[0]`` along its last dimension (one value per slice) and by the
    module-level ``noise_score``; 0.0 leaves the input unchanged. Any further
    positional inputs are passed through untouched.

    Note: ``torch.std`` defaults to the unbiased estimator, so a last
    dimension of size 1 yields NaN noise.
    """
    x = input[0]
    noise = torch.randn_like(x) * torch.std(x, dim=-1, keepdim=True) * noise_score
    # Keep extra positional inputs; returning a 1-tuple would silently drop them.
    return (x + noise, *input[1:])

In [None]:
# Encoder blocks whose `ff` sub-module is perturbed. Indices past the model
# depth are simply never matched (args.e_layers is 3 for the loaded checkpoint).
zero_ffn_blocks = [0, 1, 2, 3, 4, 5]

delete_hooks()
hooks = []
for block_idx, layer in enumerate(model.model.backbone.encoder.layers):
    if block_idx not in zero_ffn_blocks:
        continue
    # set_ffn_hook rewrites the output of the whole `ff` module.
    hooks.append(layer.ff.register_forward_hook(set_ffn_hook))
    # noise_score_hook perturbs the input of ff[1].
    hooks.append(layer.ff[1].register_forward_pre_hook(noise_score_hook))

# Both hooks read the module-level `attenuation` / `noise_score` globals,
# which the sweep below rebinds for every grid point.
attenuations = [i/10 for i in range(0, 11, 1)] # 0.0 ~ 1.0
noise_scores = [i/10*2 for i in range(0, 11, 1)] # 0.0 ~ 2.0

# losses[i, j] = test MSE at attenuations[i], noise_scores[j]
losses = np.zeros((len(attenuations), len(noise_scores)))

try:
    with torch.no_grad():
        for i_attenuation, attenuation in enumerate(attenuations):
            for j_noise_score, noise_score in enumerate(noise_scores):
                # Element-count weights give the exact MSE over the split even
                # if the final batch is smaller than the others.
                batch_losses, batch_counts = [], []
                for batch_x, batch_y, _x_mark, _y_mark in test_loader:
                    outputs = model(batch_x.float().to(device))
                    pred = outputs[:, -args.pred_len:, 0:].detach().cpu()
                    true = batch_y.float()[:, -args.pred_len:, 0:]
                    batch_losses.append(criterion(pred, true).item())
                    batch_counts.append(pred.numel())
                total_loss = np.average(batch_losses, weights=batch_counts)
                print('attenuation:', attenuation, 'noise_score:', noise_score, 'loss:', total_loss)
                losses[i_attenuation, j_noise_score] = total_loss
finally:
    # Detach the hooks so `model` is unperturbed after the sweep, including when
    # it is interrupted. Without this, the last grid point (noise_score 2.0)
    # would keep injecting noise into every later forward pass.
    for hook in hooks:
        hook.remove()
    hooks.clear()

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

# Loss surface over the (noise_score, attenuation) grid.
# `losses` is indexed [attenuation index, noise_score index], matching the
# (rows, cols) layout returned by np.meshgrid(noise_scores, attenuations).
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
grid_noise, grid_atten = np.meshgrid(noise_scores, attenuations)
ax.plot_surface(grid_noise, grid_atten, losses, rstride=1, cstride=1, cmap='rainbow')
ax.set_xlabel('noise_score')
ax.set_ylabel('attenuation')
ax.set_zlabel('loss')
# NOTE: same fixed z-range as the attention plot; losses outside [0.16, 0.18]
# may be drawn past the axes box — widen if the FFN sweep exceeds it.
ax.set_zlim(0.16, 0.18)
ax.invert_yaxis()
ax.invert_xaxis()

# Annotate the four corners of the grid with their loss.
last_atten = len(attenuations) - 1
last_noise = len(noise_scores) - 1
for row in (0, last_atten):
    for col in (0, last_noise):
        corner_loss = losses[row, col]
        ax.text(noise_scores[col], attenuations[row], corner_loss, f'{corner_loss:.3f}', color='black')

plt.show()