In [1]:
# Mount Google Drive into the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Change into the project directory; later cells open files via the relative 'data/' path.
%cd drive/MyDrive/ST/stargan

Mounted at /content/drive
/content/drive/MyDrive/ST/stargan


In [2]:
import pickle
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import LambdaLR
import torch.optim as optim
import os
import csv

# Seed torch (CPU and CUDA) and NumPy's global RNG with one shared value.
# The same value is also passed as random_state to train_test_split further down.
seed = 2710
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

In [3]:
class TSTRClassifier(nn.Module):
    """1-D CNN classifier: four conv / batch-norm / ReLU / max-pool stages, then two linear layers.

    Args:
        num_timesteps: Length of the time axis of the input. Must be at least 16
            so that something is left after the four pooling stages.
        num_channels: Number of input channels.
        num_classes: Number of output logits.

    ``forward`` takes a tensor of shape ``(batch, num_channels, num_timesteps)``
    and returns raw (un-normalised) logits of shape ``(batch, num_classes)``.

    Raises:
        ValueError: If ``num_timesteps`` is smaller than 16.
    """

    def __init__(self, num_timesteps=128, num_channels=3, num_classes=5):
        super(TSTRClassifier, self).__init__()

        self.conv1 = nn.Conv1d(num_channels, 16, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm1d(16)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm1d(32)
        self.conv3 = nn.Conv1d(32, 64, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm1d(64)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=5, stride=1, padding=2)
        self.bn4 = nn.BatchNorm1d(128)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.25)

        # The convolutions keep the length (kernel 5, padding 2); each of the
        # four pooling stages halves it with floor rounding, so the time axis
        # ends up at num_timesteps // 16 with 128 channels. For multiples of 16
        # this equals the previous hard-coded num_timesteps * 8; for any other
        # length the hard-coded size did not match the flattened tensor.
        pooled_length = num_timesteps // 16
        if pooled_length < 1:
            raise ValueError(
                f"num_timesteps must be at least 16, got {num_timesteps}"
            )
        self.fc_shared = nn.Linear(128 * pooled_length, 100)

        self.fc_class = nn.Linear(100, num_classes)

    def forward(self, x):
        """Return class logits for a batch ``x`` of shape (batch, channels, timesteps)."""
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        x = self.pool(self.relu(self.bn3(self.conv3(x))))
        x = self.pool(self.relu(self.bn4(self.conv4(x))))
        x = x.view(x.size(0), -1)  # Flatten to (batch, 128 * pooled_length)
        x = self.dropout(x)
        x = self.relu(self.fc_shared(x))

        # Final output for class prediction
        class_outputs = self.fc_class(x)
        return class_outputs



def get_syn_data(dataset, src_class, domain, mode='ref'):
    """Load synthetic samples generated from ``src_class`` for one domain.

    Reads ``data/{dataset}_syndata_{mode}_fs.pkl``, which is indexed with
    ``(src_class, trg_class)`` keys whose values unpack to ``(x, y, k)``.
    The entries for every target class other than ``src_class`` are
    concatenated, then only rows whose ``k`` equals
    ``domain + num_train_domains`` are kept.

    Args:
        dataset: ``'realworld'`` or ``'cwru'``.
        src_class: Name of the source class; it is skipped as a target.
        domain: Zero-based domain index; offset by the dataset's number of
            training domains (10 for realworld, 4 for cwru) before matching ``k``.
        mode: Substituted into the pickle file name (default ``'ref'``).

    Returns:
        Tuple ``(x, y, k)`` of arrays restricted to the requested domain.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'realworld':
        class_names = ['WAL', 'RUN', 'CLD', 'CLU']
        num_train_domains = 10
    elif dataset == 'cwru':
        class_names = ['IR', 'Ball', 'OR_centred', 'OR_orthogonal', 'OR_opposite']
        num_train_domains = 4
    else:
        # Previously this fell through and failed later with an unrelated
        # UnboundLocalError on class_names.
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'realworld' or 'cwru'")

    # Load the dataset
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(f'data/{dataset}_syndata_{mode}_fs.pkl', 'rb') as f:
        syndata = pickle.load(f)

    x = []
    y = []
    k = []

    for trg_class in class_names:
        if trg_class == src_class:
            continue
        x_, y_, k_ = syndata[(src_class, trg_class)]
        x.append(x_)
        y.append(y_)
        k.append(k_)

    x = np.concatenate(x, axis=0)
    y = np.concatenate(y, axis=0)
    k = np.concatenate(k, axis=0)

    # Build the domain mask once and apply it to all three arrays.
    in_domain = k == domain + num_train_domains

    return x[in_domain], y[in_domain], k[in_domain]



def get_fs_data(dataset, src_class, domain):
    """Load the samples flagged ``fs == 1`` for one domain, excluding ``src_class``.

    Reads ``data/{dataset_name}.pkl`` (unpacked as ``(x, y, k)``) and
    ``data/{dataset_name}_fs.pkl`` (a per-sample flag array). Rows are kept
    when the flag is 1, the label differs from the index of ``src_class`` in
    the dataset's class list, and ``k`` equals ``domain``.

    Args:
        dataset: ``'realworld'`` or ``'cwru'``.
        src_class: Class name whose samples are excluded.
        domain: Value of ``k`` to keep.

    Returns:
        Tuple ``(x, y, k)`` of the selected rows.

    Raises:
        ValueError: If ``dataset`` is unsupported, or if ``src_class`` is not
            in that dataset's class list (raised by ``list.index``).
    """
    if dataset == 'realworld':
        dataset_name = 'realworld_128_3ch_4cl'
        class_names = ['WAL', 'RUN', 'CLD', 'CLU']
    elif dataset == 'cwru':
        dataset_name = 'cwru_256_3ch_5cl'
        class_names = ['IR', 'Ball', 'OR_centred', 'OR_orthogonal', 'OR_opposite']
    else:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'realworld' or 'cwru'")

    class_idx = class_names.index(src_class)

    # Load the dataset
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(f'data/{dataset_name}.pkl', 'rb') as f:
        x, y, k = pickle.load(f)

    with open(f'data/{dataset_name}_fs.pkl', 'rb') as f:
        fs = pickle.load(f)

    # One combined mask selects the same rows, in the same order, as filtering
    # on fs first and on (label, domain) afterwards.
    keep = (fs == 1) & (y != class_idx) & (k == domain)

    return x[keep], y[keep], k[keep]


def get_dp_data(dataset_name, class_idx, domain):
    """Return the ``fs == 0`` samples of one domain, excluding label ``class_idx``.

    Reads ``data/{dataset_name}.pkl`` (unpacked as ``(x, y, k)``) and the
    per-sample flag array in ``data/{dataset_name}_fs.pkl``. A row is returned
    when its flag is 0, its label differs from ``class_idx`` and its ``k``
    equals ``domain``.
    """
    data_path = f'data/{dataset_name}.pkl'
    flag_path = f'data/{dataset_name}_fs.pkl'

    with open(data_path, 'rb') as data_file:
        samples, labels, domains = pickle.load(data_file)

    with open(flag_path, 'rb') as flag_file:
        finetune_flag = pickle.load(flag_file)

    # Rows flagged for fine-tuning (fs != 0) are excluded, together with the
    # held-out class and all other domains.
    selected = (finetune_flag == 0) & (labels != class_idx) & (domains == domain)

    return samples[selected], labels[selected], domains[selected]


def get_df_data(dataset_name, class_idx):
    """Return every ``fs == 0`` sample whose label differs from ``class_idx``.

    Reads ``data/{dataset_name}.pkl`` (unpacked as ``(x, y, k)``) and the
    per-sample flag array in ``data/{dataset_name}_fs.pkl``. No domain
    filtering is applied.
    """
    data_path = f'data/{dataset_name}.pkl'
    flag_path = f'data/{dataset_name}_fs.pkl'

    with open(data_path, 'rb') as data_file:
        samples, labels, domains = pickle.load(data_file)

    with open(flag_path, 'rb') as flag_file:
        finetune_flag = pickle.load(flag_file)

    # Drop the rows reserved for fine-tuning and the held-out class in one pass.
    selected = (finetune_flag == 0) & (labels != class_idx)

    return samples[selected], labels[selected], domains[selected]



def save_score(accuracy, loss, source, domain, name, dataset, mode):
    """Append one result row to ``bounds_fs/{name}_{dataset}.csv``.

    The ``bounds_fs`` directory is created when missing. If the CSV file does
    not exist yet, a header row (``mode, source, domain, accuracy, loss``) is
    written before the data row.
    """
    out_dir = 'bounds_fs'
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, f'{name}_{dataset}.csv')

    # A missing file gets a header first and is opened for writing;
    # an existing file is only appended to.
    is_new_file = not os.path.exists(csv_path)
    open_mode = 'w' if is_new_file else 'a'

    rows = []
    if is_new_file:
        rows.append(['mode', 'source', 'domain', 'accuracy', 'loss'])
    rows.append([mode, source, domain, accuracy, loss])

    with open(csv_path, mode=open_mode, newline='') as handle:
        csv.writer(handle).writerows(rows)




def remap_labels(y):
    """Map labels to consecutive integers ``0..n_unique-1`` in sorted label order.

    Each label is replaced by its rank among the sorted unique values of
    ``y``; for example ``[3, 1, 3, 5]`` becomes ``[1, 0, 1, 2]``.

    Args:
        y: Array-like of labels.

    Returns:
        Integer ndarray with the same shape as ``y``. The result is an integer
        array even for empty input (the previous per-element loop produced a
        float64 array in that case).
    """
    labels = np.asarray(y)
    # return_inverse gives, for every element, its index in the sorted unique
    # values, which is exactly the remapped label, without a Python-level loop.
    _, inverse = np.unique(labels, return_inverse=True)
    return inverse.reshape(labels.shape)


def setup_training(x_tr, y_tr, x_val, y_val, batch_size=64):
    """Wrap train/validation arrays in ``DataLoader`` objects.

    Inputs are converted to float32 tensors and labels to int64 (long)
    tensors. The training loader shuffles; the validation loader does not.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    def _as_dataset(features, labels):
        # Pair features and labels as a TensorDataset with the dtypes above.
        return TensorDataset(
            torch.tensor(features, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.long),
        )

    train_loader = DataLoader(_as_dataset(x_tr, y_tr), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(_as_dataset(x_val, y_val), batch_size=batch_size, shuffle=False)

    return train_loader, val_loader


def train_model(model, train_loader, val_loader, loss_fn, optimizer, epochs=300):
    """Train ``model`` in place and return the weights of its best validation epoch.

    The learning rate of ``optimizer`` is decayed linearly: after ``e``
    completed epochs it is multiplied by ``1 - e / epochs``. After every
    epoch the model is scored with ``evaluate_model`` on ``val_loader``; the
    epoch with the highest validation accuracy wins, ties going to the later
    epoch.

    Args:
        model: Module to train; it is moved to CUDA when available and its
            parameters are updated in place.
        train_loader: Iterable of ``(x_batch, y_batch)`` training batches.
        val_loader: Iterable of validation batches passed to ``evaluate_model``.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.

    Returns:
        A state dict holding an independent copy of the parameters and buffers
        from the best epoch. Note that ``model`` itself is left with the
        final-epoch weights, which may differ from the returned state.
    """
    import copy  # local import: only needed to snapshot the best weights

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Defined before the loop so they exist even if the loop body never runs.
    best_epoch = -1
    best_loss = np.inf
    best_accuracy = 0
    best_model_state = copy.deepcopy(model.state_dict())

    # Set up linear learning rate decay.
    # max(..., 1) only guards the division when epochs == 0.
    decay_span = max(epochs, 1)
    lambda_lr = lambda epoch: 1 - epoch / decay_span
    scheduler = LambdaLR(optimizer, lr_lambda=lambda_lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(x_batch)
            loss = loss_fn(outputs, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Mean of the per-batch losses for this epoch.
        total_loss /= len(train_loader)

        # Update learning rate
        scheduler.step()

        val_accuracy, val_loss = evaluate_model(model, val_loader, loss_fn)
        if val_accuracy >= best_accuracy:
            best_epoch = epoch
            best_accuracy = val_accuracy
            best_loss = val_loss
            # state_dict() returns references to the live parameter tensors and
            # dict.copy() is shallow, so the snapshot must be deep-copied;
            # otherwise later optimizer steps silently overwrite the "best"
            # weights and the final-epoch weights are returned instead.
            best_model_state = copy.deepcopy(model.state_dict())

        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {epoch + 1}/{epochs} - Train loss: {total_loss:.4f} - Val loss: {val_loss:.4f} - Val accuracy: {val_accuracy:.4f} - LR: {current_lr:.8f}")

    print(f"Best epoch: {best_epoch + 1} - Best val accuracy: {best_accuracy:.4f} - Best val loss: {best_loss:.4f}")

    return best_model_state


def evaluate_model(model, test_loader, loss_fn):
    """Score ``model`` on ``test_loader`` without tracking gradients.

    The model is moved to CUDA when available and switched to eval mode.

    Returns:
        Tuple ``(accuracy, loss)``: ``accuracy`` is the fraction of samples
        whose arg-max logit equals the label; ``loss`` is the mean of the
        per-batch ``loss_fn`` values (each batch weighted equally).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss_sum += loss_fn(logits, targets).item()

            # Index of the largest logit along the class axis.
            predictions = logits.max(dim=1).indices
            n_correct += (predictions == targets).sum().item()
            n_seen += len(targets)

    mean_batch_loss = loss_sum / len(test_loader)
    accuracy = n_correct / n_seen

    return accuracy, mean_batch_loss


def fine_tune_model(best_model, x_finetune, y_finetune):
    """Fine-tune ``best_model`` on the given data and return the best state dict.

    The data is split 80/20 into stratified train/validation parts with a
    fixed ``random_state`` of 2710, then trained for 50 epochs with Adam at an
    initial learning rate of 1e-5 and cross-entropy loss, batch size 64.

    Note that the optimizer is built over ``best_model.parameters()``, so the
    model passed in is itself updated during training.

    Returns:
        Whatever ``train_model`` returns for this run (its best model state).
    """
    x_fit, x_holdout, y_fit, y_holdout = train_test_split(
        x_finetune,
        y_finetune,
        test_size=0.2,
        random_state=2710,
        stratify=y_finetune,
        shuffle=True,
    )
    fit_loader, holdout_loader = setup_training(x_fit, y_fit, x_holdout, y_holdout, batch_size=64)

    criterion = nn.CrossEntropyLoss()
    adam = optim.Adam(best_model.parameters(), lr=1e-5)

    return train_model(best_model, fit_loader, holdout_loader, criterion, adam, epochs=50)


def calculate_tstr_score(x_train, y_train, x_finetune, y_finetune, x_test, y_test, df_model):
    """Fine-tune a copy of ``df_model`` in two stages and score it on the test set.

    Stage 1 fine-tunes on ``(x_train, y_train)``; stage 2 continues from the
    best stage-1 weights on ``(x_finetune, y_finetune)``. The best stage-2
    weights are then evaluated on ``(x_test, y_test)``.

    Args:
        x_train, y_train: Data for the first fine-tuning stage. ``x_train`` is
            indexed as ``shape[1]`` = channels and ``shape[2]`` = timesteps.
        x_finetune, y_finetune: Data for the second fine-tuning stage.
        x_test, y_test: Evaluation data.
        df_model: Pre-trained starting model. It is not modified: training runs
            on a deep copy, so repeated calls with the same ``df_model`` all
            start from identical weights.

    Returns:
        Tuple ``(test_accuracy, test_loss)`` as returned by ``evaluate_model``.

    Raises:
        AssertionError: If the three label sets do not contain the same values.
    """
    import copy  # local import: only needed to protect the caller's model

    assert np.array_equal(np.unique(y_train), np.unique(y_test)), f"Train and test labels do not match: {np.unique(y_train)} vs {np.unique(y_test)}"
    assert np.array_equal(np.unique(y_train), np.unique(y_finetune)), f"Train and finetune labels do not match: {np.unique(y_train)} vs {np.unique(y_finetune)}"
    assert np.array_equal(np.unique(y_finetune), np.unique(y_test)), f"Finetune and test labels do not match: {np.unique(y_finetune)} vs {np.unique(y_test)}"

    # Remap labels (the label sets are identical, so all three use the same mapping)
    y_train = remap_labels(y_train)
    y_finetune = remap_labels(y_finetune)
    y_test = remap_labels(y_test)

    num_timesteps = x_train.shape[2]
    num_channels = x_train.shape[1]
    num_classes = len(np.unique(y_train))

    # fine_tune_model optimizes the model it is given in place. Training the
    # caller's df_model directly would make every later call (other domains or
    # modes) start from already fine-tuned weights, so work on a copy.
    stage1_model = copy.deepcopy(df_model)
    fine_tuned_model_state_1 = fine_tune_model(stage1_model, x_train, y_train)
    fine_tuned_model_1 = TSTRClassifier(num_timesteps=num_timesteps, num_channels=num_channels, num_classes=num_classes)
    fine_tuned_model_1.load_state_dict(fine_tuned_model_state_1)

    fine_tuned_model_state_2 = fine_tune_model(fine_tuned_model_1, x_finetune, y_finetune)
    fine_tuned_model_2 = TSTRClassifier(num_timesteps=num_timesteps, num_channels=num_channels, num_classes=num_classes)
    fine_tuned_model_2.load_state_dict(fine_tuned_model_state_2)

    x_test_tensor = torch.tensor(x_test, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long)
    test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    test_accuracy, test_loss = evaluate_model(fine_tuned_model_2, test_loader, nn.CrossEntropyLoss())

    return test_accuracy, test_loss


def train_df_model(x_train, y_train):
    """Train a fresh ``TSTRClassifier`` and return a model loaded with its best weights.

    Labels are remapped to consecutive integers, the data is split 80/20
    (stratified, ``random_state=2710``), and a new classifier sized from
    ``x_train.shape`` is trained for 50 epochs with Adam (initial learning
    rate 0.0001) and cross-entropy loss. The state returned by ``train_model``
    is loaded into a second, newly constructed classifier, which is returned.
    """
    y_train = remap_labels(y_train)
    x_fit, x_holdout, y_fit, y_holdout = train_test_split(
        x_train,
        y_train,
        test_size=0.2,
        random_state=2710,
        stratify=y_train,
        shuffle=True,
    )
    fit_loader, holdout_loader = setup_training(x_fit, y_fit, x_holdout, y_holdout, batch_size=64)

    def _new_classifier():
        # Architecture is derived from the data: axis 1 = channels, axis 2 = timesteps.
        return TSTRClassifier(
            num_timesteps=x_train.shape[2],
            num_channels=x_train.shape[1],
            num_classes=len(np.unique(y_train)),
        )

    network = _new_classifier()
    criterion = nn.CrossEntropyLoss()
    adam = optim.Adam(network.parameters(), lr=0.0001)

    winning_state = train_model(network, fit_loader, holdout_loader, criterion, adam, epochs=50)

    best_model = _new_classifier()
    best_model.load_state_dict(winning_state)
    return best_model


In [4]:
# Experiment driver for the 'realworld' dataset: for every source class, train
# one base model, then score two-stage fine-tuning per domain for the 'lat' and
# 'ref' synthetic data and append each result to a CSV via save_score.
dataset = 'realworld'

# Per-dataset configuration.
if dataset == 'realworld':
    dataset_name = 'realworld_128_3ch_4cl'
    num_df_domains = 10
    num_dp_domains = 5
    num_classes = 4
    class_names = ['WAL', 'RUN', 'CLD', 'CLU']

elif dataset == 'cwru':
    dataset_name = 'cwru_256_3ch_5cl'
    num_df_domains = 4
    num_dp_domains = 4
    num_classes = 5
    class_names = ['IR', 'Ball', 'OR_centred', 'OR_orthogonal', 'OR_opposite']

# Class name -> integer label index (position in class_names).
classes_dict = {clss: i for i, clss in enumerate(class_names)}

# Test accuracies collected over all (source class, domain) pairs.
accs_lat = []
accs_ref = []

for src_class in class_names:
    # trg_classes is not used further in this cell.
    trg_classes = [clss for clss in class_names if clss != src_class]
    src_class_idx = classes_dict[src_class]

    # Base model: trained once per source class on the fs == 0 samples of all
    # classes except the source class.
    x_df, y_df, k_df = get_df_data(dataset_name, src_class_idx)
    df_model = train_df_model(x_df, y_df)

    for domain in range(num_dp_domains):
        # Stage-1 data (synthetic, two generation modes), stage-2 data (fs == 1
        # samples) and test data (fs == 0 samples) for this domain.
        x_syn_lat, y_syn_lat, k_syn_lat = get_syn_data(dataset, src_class, domain, mode='lat')
        x_syn_ref, y_syn_ref, k_syn_ref = get_syn_data(dataset, src_class, domain, mode='ref')
        x_fs, y_fs, k_fs = get_fs_data(dataset, src_class, domain)
        x_dp, y_dp, k_dp = get_dp_data(dataset_name, src_class_idx, domain)

        # The domain is stored offset by num_df_domains in the CSV.
        acc_lat, loss_lat = calculate_tstr_score(x_syn_lat, y_syn_lat, x_fs, y_fs, x_dp, y_dp, df_model)
        save_score(acc_lat, loss_lat, src_class, domain+num_df_domains, 'TSTR_FT2_fs', dataset, 'lat')
        print(f'{src_class} - {domain} - Latent: {acc_lat:.2f}\n')
        accs_lat.append(acc_lat)
        acc_ref, loss_ref = calculate_tstr_score(x_syn_ref, y_syn_ref, x_fs, y_fs, x_dp, y_dp, df_model)
        save_score(acc_ref, loss_ref, src_class, domain+num_df_domains, 'TSTR_FT2_fs', dataset, 'ref')
        print(f'{src_class} - {domain} - Reference: {acc_ref:.2f}\n')
        accs_ref.append(acc_ref)

# Mean test accuracy per synthetic-data mode.
print(f'Latent: {np.mean(accs_lat):.2f}')
print(f'Ref: {np.mean(accs_ref):.2f}')

Epoch 1/50 - Train loss: 0.4143 - Val loss: 0.1752 - Val accuracy: 0.9510 - LR: 0.00009800
Epoch 2/50 - Train loss: 0.1389 - Val loss: 0.1198 - Val accuracy: 0.9617 - LR: 0.00009600
Epoch 3/50 - Train loss: 0.0995 - Val loss: 0.1011 - Val accuracy: 0.9661 - LR: 0.00009400
Epoch 4/50 - Train loss: 0.0878 - Val loss: 0.0898 - Val accuracy: 0.9717 - LR: 0.00009200
Epoch 5/50 - Train loss: 0.0704 - Val loss: 0.0795 - Val accuracy: 0.9742 - LR: 0.00009000
Epoch 6/50 - Train loss: 0.0596 - Val loss: 0.0693 - Val accuracy: 0.9768 - LR: 0.00008800
Epoch 7/50 - Train loss: 0.0530 - Val loss: 0.0629 - Val accuracy: 0.9749 - LR: 0.00008600
Epoch 8/50 - Train loss: 0.0499 - Val loss: 0.0714 - Val accuracy: 0.9780 - LR: 0.00008400
Epoch 9/50 - Train loss: 0.0447 - Val loss: 0.0556 - Val accuracy: 0.9793 - LR: 0.00008200
Epoch 10/50 - Train loss: 0.0383 - Val loss: 0.0509 - Val accuracy: 0.9805 - LR: 0.00008000
Epoch 11/50 - Train loss: 0.0322 - Val loss: 0.0517 - Val accuracy: 0.9818 - LR: 0.000078

In [5]:
# Experiment driver for the 'cwru' dataset: for every source class, train one
# base model, then score two-stage fine-tuning per domain for the 'lat' and
# 'ref' synthetic data and append each result to a CSV via save_score.
# This cell reassigns the same global names as the 'realworld' cell above
# (dataset, class_names, accs_lat, accs_ref, ...).
dataset = 'cwru'

# Per-dataset configuration.
if dataset == 'realworld':
    dataset_name = 'realworld_128_3ch_4cl'
    num_df_domains = 10
    num_dp_domains = 5
    num_classes = 4
    class_names = ['WAL', 'RUN', 'CLD', 'CLU']

elif dataset == 'cwru':
    dataset_name = 'cwru_256_3ch_5cl'
    num_df_domains = 4
    num_dp_domains = 4
    num_classes = 5
    class_names = ['IR', 'Ball', 'OR_centred', 'OR_orthogonal', 'OR_opposite']

# Class name -> integer label index (position in class_names).
classes_dict = {clss: i for i, clss in enumerate(class_names)}

# Test accuracies collected over all (source class, domain) pairs.
accs_lat = []
accs_ref = []

for src_class in class_names:
    # trg_classes is not used further in this cell.
    trg_classes = [clss for clss in class_names if clss != src_class]
    src_class_idx = classes_dict[src_class]

    # Base model: trained once per source class on the fs == 0 samples of all
    # classes except the source class.
    x_df, y_df, k_df = get_df_data(dataset_name, src_class_idx)
    df_model = train_df_model(x_df, y_df)

    for domain in range(num_dp_domains):
        # Stage-1 data (synthetic, two generation modes), stage-2 data (fs == 1
        # samples) and test data (fs == 0 samples) for this domain.
        x_syn_lat, y_syn_lat, k_syn_lat = get_syn_data(dataset, src_class, domain, mode='lat')
        x_syn_ref, y_syn_ref, k_syn_ref = get_syn_data(dataset, src_class, domain, mode='ref')
        x_fs, y_fs, k_fs = get_fs_data(dataset, src_class, domain)
        x_dp, y_dp, k_dp = get_dp_data(dataset_name, src_class_idx, domain)

        # The domain is stored offset by num_df_domains in the CSV.
        # NOTE(review): results go to 'TSTR_FT3_fs', whereas the 'realworld'
        # cell writes 'TSTR_FT2_fs' — confirm the differing name is intended.
        acc_lat, loss_lat = calculate_tstr_score(x_syn_lat, y_syn_lat, x_fs, y_fs, x_dp, y_dp, df_model)
        save_score(acc_lat, loss_lat, src_class, domain+num_df_domains, 'TSTR_FT3_fs', dataset, 'lat')
        print(f'{src_class} - {domain} - Latent: {acc_lat:.2f}\n')
        accs_lat.append(acc_lat)
        acc_ref, loss_ref = calculate_tstr_score(x_syn_ref, y_syn_ref, x_fs, y_fs, x_dp, y_dp, df_model)
        save_score(acc_ref, loss_ref, src_class, domain+num_df_domains, 'TSTR_FT3_fs', dataset, 'ref')
        print(f'{src_class} - {domain} - Reference: {acc_ref:.2f}\n')
        accs_ref.append(acc_ref)

# Mean test accuracy per synthetic-data mode.
print(f'Latent: {np.mean(accs_lat):.2f}')
print(f'Ref: {np.mean(accs_ref):.2f}')

Epoch 1/50 - Train loss: 0.5278 - Val loss: 0.2086 - Val accuracy: 0.9148 - LR: 0.00009800
Epoch 2/50 - Train loss: 0.1280 - Val loss: 0.1291 - Val accuracy: 0.9516 - LR: 0.00009600
Epoch 3/50 - Train loss: 0.0693 - Val loss: 0.1336 - Val accuracy: 0.9426 - LR: 0.00009400
Epoch 4/50 - Train loss: 0.0523 - Val loss: 0.0468 - Val accuracy: 0.9824 - LR: 0.00009200
Epoch 5/50 - Train loss: 0.0380 - Val loss: 0.0552 - Val accuracy: 0.9784 - LR: 0.00009000
Epoch 6/50 - Train loss: 0.0276 - Val loss: 0.0259 - Val accuracy: 0.9914 - LR: 0.00008800
Epoch 7/50 - Train loss: 0.0236 - Val loss: 0.0801 - Val accuracy: 0.9675 - LR: 0.00008600
Epoch 8/50 - Train loss: 0.0168 - Val loss: 0.0145 - Val accuracy: 0.9967 - LR: 0.00008400
Epoch 9/50 - Train loss: 0.0147 - Val loss: 0.0258 - Val accuracy: 0.9901 - LR: 0.00008200
Epoch 10/50 - Train loss: 0.0125 - Val loss: 0.0120 - Val accuracy: 0.9967 - LR: 0.00008000
Epoch 11/50 - Train loss: 0.0089 - Val loss: 0.0190 - Val accuracy: 0.9937 - LR: 0.000078