In [124]:
import torch
import torch.nn as nn
import numpy as np
import time
import math
import pandas as pd
from matplotlib import pyplot

torch.manual_seed(0)
np.random.seed(0)

# S is the source sequence length
# T is the target sequence length
# N is the batch size
# E is the feature number

#src = torch.rand((10, 32, 512)) # (S,N,E) 
#tgt = torch.rand((20, 32, 512)) # (T,N,E)
#out = transformer_model(src, tgt)

input_window = 100 # number of input steps
output_window = 4 # number of prediction steps, in this model its fixed to one
block_len = input_window + output_window # for one input-output pair
batch_size = 32
train_size = 0.8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()       
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # div_term = torch.exp(
        #     torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        # )
        div_term = 1 / (10000 ** ((2 * np.arange(d_model)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term[0::2])
        pe[:, 1::2] = torch.cos(position * div_term[1::2])

        pe = pe.unsqueeze(0).transpose(0, 1) # [5000, 1, d_model],so need seq-len <= 5000
        #pe.requires_grad = False
        self.register_buffer('pe', pe)

    def forward(self, x):
        # print(self.pe[:x.size(0), :].repeat(1,x.shape[1],1).shape ,'---',x.shape)
        # dimension 1 maybe inequal batchsize
        return x + self.pe[:x.size(0), :].repeat(1,x.shape[1],1)
          

class TransAm(nn.Module):
    def __init__(self,feature_size=512,num_layers=2,dropout=0.1):
        super(TransAm, self).__init__()
        self.model_type = 'Transformer'
        self.input_embedding  = nn.Linear(3,feature_size)
        self.src_mask = None

        self.pos_encoder = PositionalEncoding(feature_size)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=8, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.decoder = nn.Linear(feature_size,3)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1    
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self,src):
        # src with shape (input_window, batch_len, 1)
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.input_embedding(src)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src,self.src_mask)#, self.src_mask)
        output = self.decoder(output)
        #print(output.shape)output
        # Additional transformation:
        output = output.permute(2, 1, 0)       # [3, 32, 10]
        layer = nn.Linear(10, 4).to(output.device)  # must be on the same device
        output = layer(output)                 # [3, 32, 4]
        output = output.permute(1, 0, 2)# [32, 3, 4]
        output = output.permute(0,2,1)
        return output

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

In [125]:
def create_inout_sequences(input_data, input_window, output_window):
    inout_seq = []
    L = len(input_data)

    for i in range(L - input_window - output_window + 1):
        input_seq = input_data[i : i + input_window]
        output_seq = input_data[i + input_window : i + input_window + output_window]
        inout_seq.append((
            torch.tensor(input_seq, dtype=torch.float32),
            torch.tensor(output_seq, dtype=torch.float32)
        ))
    print(len(inout_seq))
    

    return inout_seq  # ← Return as list of tuples



In [204]:
import pandas as pd
import torch
from sklearn.preprocessing import MinMaxScaler

def get_data(input_window, output_window, train_size, device):
    # Load the dataset
    df = pd.read_csv('interpolated_keff_1000_rows.csv')
    print(df.columns.tolist())
    # Select the required columns
    df.columns = df.columns.str.strip()  # Remove extra spaces from column names
    df = df[['Time [days]', 'keff_1', 'Unnamed: 3']]
    
    # Convert to NumPy array
    data = df.to_numpy()

    # Normalize all 3 columns to range [-1, 1]
    scaler = MinMaxScaler(feature_range=(-1, 1))
    data_normalized = scaler.fit_transform(data)

    # Split into train and test sets
    #num_samples = len(data_normalized)
    num_samples = len(data)
    train_samples = int(num_samples * train_size)
    #train_samples = int(num_samples*train_size)
    train_data = data_normalized[:train_samples]
    test_data = data_normalized[train_samples:]
    #train_data = data[:train_samples]
    #test_data = data[train_samples:]
    
    # Create input-output sequences
    train_sequence = create_inout_sequences(train_data, input_window, output_window)
    test_sequence = create_inout_sequences(test_data, input_window, output_window)
    train_sequence = [(x.to(device), y.to(device)) for x, y in train_sequence]
    test_sequence = [(x.to(device), y.to(device)) for x, y in test_sequence]



    return train_sequence,test_sequence



In [205]:
def get_batch(input_data, i, batch_size):
    batch_len = min(batch_size, len(input_data) - i)
    data = input_data[i : i + batch_len]

    # input and target will have shape: [batch_len, seq_len, num_features]
    input_batch = torch.stack([item[0] for item in data])
    target_batch = torch.stack([item[1] for item in data])

    # Transpose to match model expectation: [seq_len, batch_len, num_features]
    input_batch = input_batch.permute(1, 0, 2)
    target_batch = target_batch.permute(1, 0, 2)

    return input_batch, target_batch


In [206]:
def train(train_data):
    model.train()
    total_loss = 0.
    start_time = time.time()
    all_losses = []  # 🔸 Add this to store losses

    for batch, i in enumerate(range(0, len(train_data), batch_size)):
        data, targets = get_batch(train_data, i , batch_size)
        print('data: ',data.shape)
        print('target',targets.shape)
        targets = targets.permute(1,0,2)
        print('targets',targets.shape)
        optimizer.zero_grad()
        output = model(data)
        #print(output.shape)

        loss = criterion(output, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.7)
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        all_losses.append(batch_loss)  # 🔸 Save the batch loss

        log_interval = int(len(train_data) / batch_size / 5)
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.6f} | {:5.2f} ms | '
                  'loss {:5.5f}'.format(
                    epoch, batch, len(train_data) // batch_size,
                    scheduler.get_last_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss))
            total_loss = 0
            start_time = time.time()

    return all_losses  # 🔸 Return the list of all losses


In [211]:
def plot_and_loss(eval_model, data_source, epoch):
    eval_model.eval() 
    total_loss = 0.
    test_result = torch.Tensor(0)    
    truth = torch.Tensor(0)
    with torch.no_grad():
        for i in range(len(data_source)):
            data, target = get_batch(data_source, i, 1)  # one-step forecast
            target = target.permute(1, 0, 2)
            output = eval_model(data)            
            total_loss += criterion(output, target).item()
            test_result = torch.cat((test_result, output[-1].reshape(-1).cpu()), 0)
            truth = torch.cat((truth, target[-1].reshape(-1).cpu()), 0)
    
    import matplotlib.pyplot as plt
    
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 10), sharex=True, 
                                   gridspec_kw={'height_ratios': [2.5, 1]})
    
    # --- Plot 1: Predictions vs. Actual Values ---
    time_steps = range(len(truth))
    
    # Fill area between prediction and actual for better visualization
    ax1.fill_between(time_steps, truth.numpy(), test_result.numpy(), 
                     alpha=0.3, color='lightblue', label='Prediction Gap')
    
    # Plot actual values with thicker line
    ax1.plot(truth.numpy(), label='Actual Values', color='#1f77b4', 
             linewidth=2.5, alpha=0.9)
    
    # Plot predictions with contrasting color and style
    ax1.plot(test_result.numpy(), label='Predictions', color='#ff7f0e', 
             linestyle='--', linewidth=2, alpha=0.8)
    
    # Styling improvements
    ax1.set_title(f'Model Predictions vs. Actual Values (Epoch {epoch})', 
                  fontsize=18, fontweight='bold', pad=20)
    ax1.set_ylabel('Value', fontsize=14, fontweight='semibold')
    ax1.legend(loc='upper left', fontsize=12, framealpha=0.9)
    ax1.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    ax1.set_facecolor('#fafafa')
    
    # Add statistics text box
    mae = torch.mean(torch.abs(test_result - truth)).item()
    rmse = torch.sqrt(torch.mean((test_result - truth)**2)).item()
    stats_text = f'MAE: {mae:.2f}\nRMSE: {rmse:.2f}'
    ax1.text(0.02, 0.98, stats_text, transform=ax1.transAxes, 
             verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
             fontsize=11)
    
    # --- Plot 2: Error Analysis ---
    error = test_result - truth
    
    # Color-coded error bars
    colors = ['red' if e < 0 else 'green' for e in error.numpy()]
    ax2.bar(time_steps, error.numpy(), color=colors, alpha=0.6, width=1.0)
    
    # Zero reference line
    ax2.axhline(y=0, color='black', linestyle='-', linewidth=1.5, alpha=0.7)
    
    # Moving average of error for trend visualization
    if len(error) > 50:
        window_size = min(50, len(error) // 10)
        error_ma = error.numpy()
        error_smooth = np.convolve(error_ma, np.ones(window_size)/window_size, mode='same')
        ax2.plot(time_steps, error_smooth, color='darkblue', linewidth=2, 
                label=f'Moving Avg (window={window_size})', alpha=0.8)
        ax2.legend(loc='upper right', fontsize=10)
    
    # Styling for error plot
    ax2.set_xlabel('Time Step', fontsize=14, fontweight='semibold')
    ax2.set_ylabel('Error\n(Pred - Actual)', fontsize=12, fontweight='semibold')
    ax2.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    ax2.set_facecolor('#fafafa')
    
    # Add error statistics
    error_mean = torch.mean(error).item()
    error_std = torch.std(error).item()
    error_stats = f'Mean: {error_mean:.2f}\nStd: {error_std:.2f}'
    ax2.text(0.98, 0.02, error_stats, transform=ax2.transAxes, 
             verticalalignment='bottom', horizontalalignment='right',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
             fontsize=10)
    
    # Improve overall appearance
    plt.tight_layout()
    plt.subplots_adjust(hspace=0.1)  # Reduce space between subplots
    
    # Save with better quality and format
    plt.savefig(f'graph/transformer-epoch-{epoch}.png', 
                dpi=300, bbox_inches='tight', facecolor='white')
    
    plt.close(fig)
    return total_loss / len(data_source)  # Fixed: use len() instead of i

In [212]:
def predict_future(eval_model, data_source,steps):
    eval_model.eval() 
    total_loss = 0.
    test_result = torch.Tensor(0)    
    truth = torch.Tensor(0)
    data, _ = get_batch(data_source , 0 , 1)
    with torch.no_grad():
        for i in range(0, steps):            
            output = eval_model(data[-input_window:])
            # (seq-len , batch-size , features-num)
            # input : [ m,m+1,...,m+n ] -> [m+1,...,m+n+1]
            data = torch.cat((data, output[-1:])) # [m,m+1,..., m+n+1]

    data = data.cpu().view(-1)
    
    # I used this plot to visualize if the model pics up any long therm structure within the data.
    pyplot.plot(data,color="red")       
    pyplot.plot(data[:input_window],color="blue")    
    pyplot.grid(True, which='both')
    pyplot.axhline(y=0, color='k')
    pyplot.savefig('graph/transformer-future%d.png'%steps)
    pyplot.show()
    pyplot.close()
        

In [213]:
def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    eval_batch_size = 32
    with torch.no_grad():
        # for i in range(0, len(data_source) - 1, eval_batch_size): # Now len-1 is not necessary
        for i in range(0, len(data_source), eval_batch_size):
            data, targets = get_batch(data_source, i,eval_batch_size)
            output = eval_model(data)            
            total_loss += len(data[0]) * criterion(output, targets).cpu().item()
    return total_loss / len(data_source)

In [214]:
train_data, val_data = get_data(input_window = 10,output_window = 4,train_size=0.6,device = 'cuda')
model = TransAm().to(device)

criterion = nn.MSELoss()
lr = 0.005 
#optimizer = torch.optim.SGD(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.95)

best_val_loss = float("inf")
epochs = 100 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(train_data)
    if ( epoch % 5 == 0 ):
        val_loss = plot_and_loss(model, val_data,epoch)
        #predict_future(model, val_data,200)
    #else:
        #val_loss = evaluate(model, val_data)
   
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s |'.format(epoch, (time.time() - epoch_start_time),
                                     ))
    print('-' * 89)
    #val_loss, math.exp(val_loss  valid loss {:5.5f} | valid ppl {:8.2f}


['Time [days]', 'keff_1', 'Step', 'Unnamed: 3']
587
387
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
| epoch   1 |     3/   18 batches | lr 0.005000 | 14.74 ms | loss 26.77579
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
| epoch   1 |     6/   18 batches | lr 0.005000 | 10.62 ms | loss 10.30722
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3

  pe[:, 0::2] = torch.sin(position * div_term[0::2])
  pe[:, 1::2] = torch.cos(position * div_term[1::2])


| epoch   1 |    12/   18 batches | lr 0.005000 | 10.94 ms | loss 1.27428
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
| epoch   1 |    15/   18 batches | lr 0.005000 | 10.84 ms | loss 1.59503
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
data:  torch.Size([10, 32, 3])
target torch.Size([4, 32, 3])
targets torch.Size([32, 4, 3])
data:  torch.Size([10, 11, 3])
target torch.Size([4, 11, 3])
targets torch.Size([11, 4, 3])
| epoch   1 |    18/   18 batches | lr 0.005000 |  9.63 ms | loss 0.67049
-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  0.20s |
-----------------------------------------------------------------------------------------
data:  tor