In [None]:
'''
Train an S4 model on sequential CIFAR10 / sequential MNIST with PyTorch for demonstration purposes.
This code borrows heavily from https://github.com/kuangliu/pytorch-cifar.

This file only depends on the standalone S4 layer
available in /models/s4/

* Train standard sequential CIFAR:
    python -m example
* Train sequential CIFAR grayscale:
    python -m example --grayscale
* Train MNIST:
    python -m example --dataset mnist --d_model 256 --weight_decay 0.0

The `S4Model` class defined in this file provides a simple backbone to train S4 models.
This backbone is a good starting point for many problems, although some tasks (especially generation)
may require using other backbones.

The default CIFAR10 model trained by this file should get
89+% accuracy on the CIFAR10 test set in 80 epochs.

Each epoch takes approximately 7m20s on a T4 GPU (will be much faster on V100 / A100).
'''



parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
# Optimizer
parser.add_argument('--lr', default=0.01, type=float, help='Learning rate')
parser.add_argument('--weight_decay', default=0.01, type=float, help='Weight decay')
# Scheduler
# parser.add_argument('--patience', default=10, type=float, help='Patience for learning rate scheduler')
parser.add_argument('--epochs', default=100, type=float, help='Training epochs')
# Dataset
parser.add_argument('--dataset', default='cifar10', choices=['mnist', 'cifar10'], type=str, help='Dataset')
parser.add_argument('--grayscale', action='store_true', help='Use grayscale CIFAR10')
# Dataloader
parser.add_argument('--num_workers', default=4, type=int, help='Number of workers to use for dataloader')
parser.add_argument('--batch_size', default=64, type=int, help='Batch size')
# Model
parser.add_argument('--n_layers', default=4, type=int, help='Number of layers')
parser.add_argument('--d_model', default=128, type=int, help='Model dimension')
parser.add_argument('--dropout', default=0.1, type=float, help='Dropout')
parser.add_argument('--prenorm', action='store_true', help='Prenorm')
# General
parser.add_argument('--resume', '-r', action='store_true', help='Resume from checkpoint')

args = parser.parse_args()


best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print(f'==> Preparing {args.dataset} data..')

def split_train_val(train, val_split):
    train_len = int(len(train) * (1.0-val_split))
    train, val = torch.utils.data.random_split(
        train,
        (train_len, len(train) - train_len),
        generator=torch.Generator().manual_seed(42),
    )
    return train, val

if args.dataset == 'cifar10':

    if args.grayscale:
        transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize(mean=122.6 / 255.0, std=61.0 / 255.0),
            transforms.Lambda(lambda x: x.view(1, 1024).t())
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            transforms.Lambda(lambda x: x.view(3, 1024).t())
        ])

    # S4 is trained on sequences with no data augmentation!
    transform_train = transform_test = transform

    trainset = torchvision.datasets.CIFAR10(
        root='./data/cifar/', train=True, download=True, transform=transform_train)
    trainset, _ = split_train_val(trainset, val_split=0.1)

    valset = torchvision.datasets.CIFAR10(
        root='./data/cifar/', train=True, download=True, transform=transform_test)
    _, valset = split_train_val(valset, val_split=0.1)

    testset = torchvision.datasets.CIFAR10(
        root='./data/cifar/', train=False, download=True, transform=transform_test)

    d_input = 3 if not args.grayscale else 1
    d_output = 10

elif args.dataset == 'mnist':

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(1, 784).t())
    ])
    transform_train = transform_test = transform

    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform_train)
    trainset, _ = split_train_val(trainset, val_split=0.1)

    valset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform_test)
    _, valset = split_train_val(valset, val_split=0.1)

    testset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform_test)

    d_input = 1
    d_output = 10
else: raise NotImplementedError

# Dataloaders
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
valloader = torch.utils.data.DataLoader(
    valset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)



# Model


if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']




In [None]:
class S4Model(nn.Module):

    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=4,
        dropout=0.2,
        prenorm=False,
    ):
        super().__init__()

        self.prenorm = prenorm

        # Linear encoder (d_input = 1 for grayscale and 3 for RGB)
        self.encoder = nn.Linear(d_input, d_model)

        # Stack S4 layers as residual blocks
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(n_layers):
            self.s4_layers.append(
                S4D(d_model, dropout=dropout, transposed=True, lr=min(0.001, args.lr))
            )
            self.norms.append(nn.LayerNorm(d_model))
            self.dropouts.append(dropout_fn(dropout))

        # Linear decoder
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input)
        """
        x = self.encoder(x)  # (B, L, d_input) -> (B, L, d_model)

        x = x.transpose(-1, -2)  # (B, L, d_model) -> (B, d_model, L)
        for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts):
            # Each iteration of this loop will map (B, d_model, L) -> (B, d_model, L)

            z = x
            if self.prenorm:
                # Prenorm
                z = norm(z.transpose(-1, -2)).transpose(-1, -2)

            # Apply S4 block: we ignore the state input and output
            z, _ = layer(z)

            # Dropout on the output of the S4 block
            z = dropout(z)

            # Residual connection
            x = z + x

            if not self.prenorm:
                # Postnorm
                x = norm(x.transpose(-1, -2)).transpose(-1, -2)

        x = x.transpose(-1, -2)

        # Pooling: average pooling over the sequence length
        x = x.mean(dim=1)

        # Decode the outputs
        x = self.decoder(x)  # (B, d_model) -> (B, d_output)

        return x

In [8]:
def setup_optimizer(model, lr, weight_decay, epochs):
    """
    S4 requires a specific optimizer setup.

    The S4 layer (A, B, C, dt) parameters typically
    require a smaller learning rate (typically 0.001), with no weight decay.

    The rest of the model can be trained with a higher learning rate (e.g. 0.004, 0.01)
    and weight decay (if desired).
    """

    # All parameters in the model
    all_parameters = list(model.parameters())

    # General parameters don't contain the special _optim key
    params = [p for p in all_parameters if not hasattr(p, "_optim")]

    # Create an optimizer with the general parameters
    optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay)

    # Add parameters with special hyperparameters
    hps = [getattr(p, "_optim") for p in all_parameters if hasattr(p, "_optim")]
    hps = [
        dict(s) for s in sorted(list(dict.fromkeys(frozenset(hp.items()) for hp in hps)))
    ]  # Unique dicts
    for hp in hps:
        params = [p for p in all_parameters if getattr(p, "_optim", None) == hp]
        optimizer.add_param_group(
            {"params": params, **hp}
        )

    # Create a lr scheduler
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=patience, factor=0.2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Print optimizer info
    keys = sorted(set([k for hp in hps for k in hp.keys()]))
    for i, g in enumerate(optimizer.param_groups):
        group_hps = {k: g.get(k, None) for k in keys}
        print(' | '.join([
            f"Optimizer group {i}",
            f"{len(g['params'])} tensors",
        ] + [f"{k} {v}" for k, v in group_hps.items()]))

    return optimizer, scheduler



In [None]:

###############################################################################
# Everything after this point is standard PyTorch training!
###############################################################################

# Training
def train():
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(enumerate(trainloader))
    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (batch_idx, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total)
        )


def eval(epoch, dataloader, checkpoint=False):
    global best_acc
    model.eval()
    eval_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        pbar = tqdm(enumerate(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            eval_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (batch_idx, len(dataloader), eval_loss/(batch_idx+1), 100.*correct/total, correct, total)
            )

    # Save checkpoint.
    if checkpoint:
        acc = 100.*correct/total
        if acc > best_acc:
            state = {
                'model': model.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/ckpt.pth')
            best_acc = acc

        return acc

pbar = tqdm(range(start_epoch, args.epochs))
for epoch in pbar:
    if epoch == 0:
        pbar.set_description('Epoch: %d' % (epoch))
    else:
        pbar.set_description('Epoch: %d | Val acc: %1.3f' % (epoch, val_acc))
    train()
    val_acc = eval(epoch, valloader, checkpoint=True)
    eval(epoch, testloader)
    scheduler.step()
    # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}")

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from s4 import S4Block as S4  # Can use full version instead of minimal S4D standalone below

from tqdm.auto import tqdm

# Dropout broke in PyTorch 1.11
if tuple(map(int, torch.__version__.split('.')[:2])) == (1, 11):
    print("WARNING: Dropout is bugged in PyTorch 1.11. Results may be worse.")
    dropout_fn = nn.Dropout
if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 12):
    dropout_fn = nn.Dropout1d
else:
    dropout_fn = nn.Dropout2d


In [3]:
import torch
from torch import nn, Tensor
import math
from s4 import S4Block as S4 
from torch.utils.data import DataLoader, random_split, Dataset
import numpy as np
import random
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=500):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, feature_num]
        """
        x = x + self.pe[:x.size(1)].unsqueeze(0)
        return self.dropout(x)
    


class S4ModelForRUL(nn.Module):
    def __init__(self, d_input, d_model=512, n_layers=4, dropout=0.1, max_len=500):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout, max_len=max_len)
        self.encoder = nn.Linear(d_input, d_model)
        self.bn_encoder = nn.BatchNorm1d(max_len) 
        self.s4_layers = nn.ModuleList()
        for _ in range(n_layers):
            self.s4_layers.append(S4(d_model, dropout=dropout, transposed=True))
        self.decoder = nn.Linear(d_model, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
       
        src = self.encoder(src)  # [batch_size, seq_len, d_input] -> [batch_size, seq_len, d_model]
          # Apply tanh activation function
        src = self.bn_encoder(src)  
        src = F.tanh(src)
        src = self.pos_encoder(src)
        src = src.transpose(1, 2)  # S4 expects [batch_size, d_model, seq_len]
        for layer in self.s4_layers:
            src, _ = layer(src)  # We ignore the state output here
        src = src.transpose(1, 2)  # Back to [batch_size, seq_len, d_model]
        src = self.dropout(src)
        output = self.decoder(src) 
        return output




In [4]:

class load_data(Dataset):
    """
    root = new | old
    """
    def __init__(self, name, seq_len, root='new') -> None:
        super().__init__()
        data_root = "data/units/"
        if root == 'old':
            label_root = "data/labels/"
        elif root == 'new':
            label_root = "data/new_labels/"
        else:
            raise RuntimeError("got invalid parameter root='{}'".format(root))
        raw = np.loadtxt(data_root+name)[:,2:]
        lbl = np.loadtxt(label_root+name)/Rc
        l = len(lbl)
        if l<seq_len:
            raise RuntimeError("seq_len {} is too big for file '{}' with length {}".format(seq_len, name, l))
        raw, lbl = torch.tensor(raw, dtype=torch.float), torch.tensor(lbl, dtype=torch.float)
        lbl_pad_0 = [torch.ones([seq_len-i-1]) for i in range(seq_len-1)] 
        data_pad_0 = [torch.zeros([seq_len-i-1,24]) for i in range(seq_len-1)]
        lbl_pad_1 = [torch.zeros([i+1]) for i in range(seq_len-1)] 
        data_pad_1 = [torch.zeros([i+1,24]) for i in range(seq_len-1)]
        self.data = [torch.cat([data_pad_0[i],raw[:i+1]],0) for i in range(seq_len-1)] 
        self.data += [raw[i-seq_len+1:i+1] for i in range(seq_len-1, l)]
        self.data += [torch.cat([raw[l-seq_len+i+1:], data_pad_1[i]],0) for i in range(seq_len-1)]
        self.label = [torch.cat([lbl_pad_0[i],lbl[:i+1]],0) for i in range(seq_len-1)] 
        self.label += [lbl[i-seq_len+1:i+1] for i in range(seq_len-1, l)]
        self.label += [torch.cat([lbl[l-seq_len+i+1:], lbl_pad_1[i]],0) for i in range(seq_len-1)]
        self.padding = [torch.cat([torch.ones(seq_len-i-1), torch.zeros(i+1)],0) for i in range(seq_len-1)]   # 1 for ingore
        self.padding += [torch.zeros(seq_len) for i in range(seq_len-1, l)]
        self.padding += [torch.cat([torch.zeros(seq_len-i-1), torch.ones(i+1)],0) for i in range(seq_len-1)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.label[index], self.padding[index]


class load_all_data(Dataset):
    """
    root: new | old
    name: LIST of txt files to collect 
    """
    def __init__(self, name, seq_len) -> None:
        super().__init__()
        data_root = "data/units/"
        label_root = "data/new_labels/"
        lis = os.listdir(data_root)
        data_list = [i for i in lis if i in name]
        self.data, self.label, self.padding = [], [], []
        for n in data_list:
            raw = np.loadtxt(data_root+n)[:,2:]
            lbl = np.loadtxt(label_root+n)/Rc
            l = len(lbl)
            if l<seq_len:
                raise RuntimeError("seq_len {} is too big for file '{}' with length {}".format(seq_len, n, l))
            raw, lbl = torch.tensor(raw, dtype=torch.float), torch.tensor(lbl, dtype=torch.float)
            lbl_pad_0 = [torch.ones([seq_len-i-1]) for i in range(seq_len-1)] 
            data_pad_0 = [torch.zeros([seq_len-i-1,24]) for i in range(seq_len-1)]
            lbl_pad_1 = [torch.zeros([i+1]) for i in range(seq_len-1)] 
            data_pad_1 = [torch.zeros([i+1,24]) for i in range(seq_len-1)]
            self.data += [torch.cat([data_pad_0[i],raw[:i+1]],0) for i in range(seq_len-1)] 
            self.data += [raw[i-seq_len+1:i+1] for i in range(seq_len-1, l)]
            self.data += [torch.cat([raw[l-seq_len+i+1:], data_pad_1[i]],0) for i in range(seq_len-1)]
            self.label += [torch.cat([lbl_pad_0[i],lbl[:i+1]],0) for i in range(seq_len-1)] 
            self.label += [lbl[i-seq_len+1:i+1] for i in range(seq_len-1, l)]
            self.label += [torch.cat([lbl[l-seq_len+i+1:], lbl_pad_1[i]],0) for i in range(seq_len-1)]
            self.padding += [torch.cat([torch.ones(seq_len-i-1), torch.zeros(i+1)],0) for i in range(seq_len-1)]   # 1 for ingore
            self.padding += [torch.zeros(seq_len) for i in range(seq_len-1, l)]
            self.padding += [torch.cat([torch.zeros(seq_len-i-1), torch.ones(i+1)],0) for i in range(seq_len-1)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.label[index], self.padding[index]

In [5]:
name = 'FD001'

In [6]:
tr = np.loadtxt("save/"+name+"/train"+name+".txt", dtype=str).tolist()
val = np.loadtxt("save/"+name+"/valid"+name+".txt", dtype=str).tolist()
ts = np.loadtxt("save/"+name+"/test"+name+".txt", dtype=str).tolist()

target = ts+val


In [45]:
device = 'cuda:2' if torch.cuda.is_available() else 'cpu'
d_input = 24  
seq_len = 70
Rc = 130
model = S4ModelForRUL(d_input=d_input, d_model=512, n_layers=1, dropout=0.1, max_len=seq_len)
model.to(device)
Loss = nn.MSELoss()
Loss.to(device)

# opt = torch.optim.Adam(model.parameters(), lr=0.01)
# sch = torch.optim.lr_scheduler.StepLR(opt, 50, 0.5)
epochs = 100
opt, sch = setup_optimizer(
    model, lr=0.02, weight_decay=1e-4, epochs=epochs
)


Optimizer group 0 | 9 tensors | weight_decay 0.0001
Optimizer group 1 | 1 tensors | weight_decay 0.0001
Optimizer group 2 | 5 tensors | weight_decay 0.0


In [46]:
import torch
from torch.utils.data import DataLoader
import random
import math

def train(data, model, loss_function, optimizer, seq_len, epochs, device, name):
    min_rmse = float('inf')
    for e in range(epochs):
        model.train()
        random.shuffle(data)
        train_data = load_all_data(data, seq_len=seq_len)  # Ensure this returns a dataset compatible with DataLoader
        total_loss = 0.0
        train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
       
        for train_data, train_label, train_padding in train_loader:
            
            train_data, train_label = train_data.to(device), train_label.to(device)
            optimizer.zero_grad()
            output = model(train_data).squeeze()  # Adjusted to pass only train_data
            
            
            loss = loss_function(output, train_label)
            total_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
            optimizer.step()

        
       
        
        rmse = validate(model, seq_len, device, target)  # Adjust validate function call accordingly
        print(f"Epoch: {e}, Loss: {total_loss / len(train_loader)}, RMSE: {rmse}")
        
        if rmse < min_rmse:
            min_rmse = rmse
            torch.save(model.state_dict(), f'save/s4_{name}.pth')
        
        sch.step()
    
    return min_rmse

            
            
            
        
def validate(model, seq_len, device, val_data):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for i in val_data:
            pred_sum, pred_cnt = torch.zeros(800), torch.zeros(800)
            valid_data = load_data(i, seq_len)  # Ensure this returns a dataset compatible with DataLoader
            valid_loader = DataLoader(valid_data, batch_size=1000, shuffle=False)
            
           
            for valid_data, valid_label, valid_padding in valid_loader:
               
                valid_data = valid_data.to(device)
                data_len = len(valid_data)
                output = model(valid_data).squeeze(2).cpu()  # Adjusted to pass only valid_data
                # Proceed with your RMSE calculation

                
                
                for j in range(data_len):
                    if j < seq_len-1:
                    
                        pred_sum[:j+1] += output[j, -(j+1):]
                        pred_cnt[:j+1] += 1
                    elif j <= data_len-seq_len:
                        pred_sum[j-seq_len+1:j+1] += output[j]
                        pred_cnt[j-seq_len+1:j+1] += 1
                    else:
                        pred_sum[data_len-seq_len+1-(data_len-j):data_len-seq_len+1] += output[j, :(data_len-j)]
                        pred_cnt[data_len-seq_len+1-(data_len-j):data_len-seq_len+1] += 1
                truth = torch.tensor([valid_label[j,-1] for j in range(len(valid_label)-seq_len+1)], dtype=torch.float)
                pred_sum, pred_cnt = pred_sum[:data_len-seq_len+1], pred_cnt[:data_len-seq_len+1]
                pred = pred_sum/pred_cnt
                mse = float(torch.sum(torch.pow(pred-truth, 2)))
                rmse = math.sqrt(mse/data_len)
                total_loss += rmse
        return total_loss*Rc/len(val_data)
                
               

train(tr, model, Loss, opt, seq_len, epochs, device, name)

Epoch: 0, Loss: 14.088769018209565, RMSE: 13.985892764445618
Epoch: 1, Loss: 0.01889443943100805, RMSE: 11.91389800098081
Epoch: 2, Loss: 0.014130031960588452, RMSE: 12.433026148023858
Epoch: 3, Loss: 0.013327636051218252, RMSE: 12.71205144021184


KeyboardInterrupt: 

In [39]:
def score(pred, truth):
    """input must be tensors!"""
    x = pred-truth
    score1 = torch.tensor([torch.exp(-i/13)-1 for i in x if i<0])
    score2 = torch.tensor([torch.exp(i/10)-1 for i in x if i>=0])
    return int(torch.sum(score1)+torch.sum(score2))


def get_pred_result(data_len, out, lb):
    pred_sum, pred_cnt = torch.zeros(800), torch.zeros(800)
    for j in range(data_len):
        if j < seq_len-1:
            pred_sum[:j+1] += out[j, -(j+1):]
            pred_cnt[:j+1] += 1
        elif j <= data_len-seq_len:
            pred_sum[j-seq_len+1:j+1] += out[j]
            pred_cnt[j-seq_len+1:j+1] += 1
        else:
            pred_sum[data_len-seq_len+1-(data_len-j):data_len-seq_len+1] += out[j, :(data_len-j)]
            pred_cnt[data_len-seq_len+1-(data_len-j):data_len-seq_len+1] += 1
    truth = torch.tensor([lb[j,-1] for j in range(len(lb[0])-seq_len+1)], dtype=torch.float)
    
    pred_sum, pred_cnt = pred_sum[:data_len-seq_len+1], pred_cnt[:data_len-seq_len+1]
    pred2 = pred_sum/pred_cnt
    pred2 *= Rc
    truth *= Rc
    return truth, pred2 

In [42]:
import torch
from torch.utils.data import DataLoader

# Assuming you have a 'load_data' or 'load_all_data' class for handling test datasets
# and the 'score' function for scoring the predictions.

def load_model(model_path, device):
    model = S4ModelForRUL(d_input=24, d_model=512, n_layers=1, dropout=0.1, max_len=70)  # Adjust parameters as necessary
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

def test_model(model, test_data, seq_len, device):
    tot = 0.0
    tot_sc= 0
    with torch.no_grad():
        for _ in range(test_len):
            i = next(test_iter)
            
            test_dataset = load_data(i, seq_len=seq_len)  # Adjust depending on your data loader
            data_len = len(test_dataset)
            test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)  # Batch size can be adjusted
            it = iter(test_loader)
            d = next(it)
            input, lb, msk = d[0], d[1], d[2]
            input = input.to(device)
            predictions = model(input).squeeze(2)
            
            truth, pred = get_pred_result(data_len,  predictions.to('cpu'), lb.to('cpu'))
            
            mse = float(torch.sum(torch.pow(pred-truth, 2)))
            rmse = math.sqrt(mse/data_len)
            tot += rmse
            sc = score(pred, truth)
            tot_sc += sc
            print("for file {}: rmse={:.4f}, score={}".format(i, rmse, sc))
            print('-'*80)
            
            
                
                
                
                
                
    avg_rmse = tot / len(test_data)
    avg_score = tot_sc / len(test_data)
    print("tested on [{}] files, mean RMSE = {:.4f}, mean score = {}".format(len(test_data), tot/len(test_data), int(tot_sc/len(test_data))))
    
    return avg_rmse, avg_score

device = 'cuda:2' if torch.cuda.is_available() else 'cpu'  # Adjust the device as needed
model_path = 'save/s4_FD001.pth'  # Adjust the path to your saved model

test_data = tr  # Your test set file names
test_len = len(ts)
test_iter = iter(ts)

model = load_model(model_path, device)
avg_rmse, avg_score = test_model(model, test_data, seq_len, device)




for file FD001-22.txt: rmse=54.3685, score=230763
--------------------------------------------------------------------------------
for file FD001-23.txt: rmse=50.4937, score=179187
--------------------------------------------------------------------------------
for file FD001-36.txt: rmse=51.2551, score=180114
--------------------------------------------------------------------------------
for file FD001-8.txt: rmse=52.1057, score=184408
--------------------------------------------------------------------------------
for file FD001-10.txt: rmse=45.8891, score=175925
--------------------------------------------------------------------------------
for file FD001-4.txt: rmse=48.6467, score=217472
--------------------------------------------------------------------------------
for file FD001-83.txt: rmse=44.1214, score=222225
--------------------------------------------------------------------------------
for file FD001-86.txt: rmse=49.0491, score=248496
-----------------------------------

In [120]:
x=torch.load(f'save/s4_{name}.pth', map_location='cuda:2')
model.load_state_dict(x)
data_root = "data/units/"
label_root = "data/labels/"
lis = os.listdir(data_root)
test_list = [i for i in lis if i[:5] == name]
random.shuffle(test_list)
test_len = len(test_list)
list_iter = iter(test_list)
test()
    

for file FD001-35.txt: rmse=58.6192, score=8900991
--------------------------------------------------------------------------------
for file FD001-71.txt: rmse=55.7260, score=8901015
--------------------------------------------------------------------------------
for file FD001-43.txt: rmse=55.8255, score=8901015
--------------------------------------------------------------------------------
for file FD001-24.txt: rmse=63.0116, score=8900959
--------------------------------------------------------------------------------
for file FD001-10.txt: rmse=54.3877, score=8901028
--------------------------------------------------------------------------------
for file FD001-73.txt: rmse=55.2365, score=8901019
--------------------------------------------------------------------------------
for file FD001-31.txt: rmse=53.3155, score=8901039
--------------------------------------------------------------------------------
for file FD001-61.txt: rmse=58.1615, score=8900995
-------------------------