In [75]:
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import pandas as pd
import numpy as np
from dataclasses import dataclass

In [76]:
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import pandas as pd
import numpy as np
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


def to_shape(a, shape):
    """Zero-pad a 2-D array so that it has exactly ``shape``.

    The padding on each axis is split as evenly as possible between the two
    sides; when the amount is odd, the extra row/column goes after the data.

    Parameters
    ----------
    a : np.ndarray
        2-D input array.
    shape : tuple[int, int]
        Target ``(rows, cols)``; each must be >= the matching size of ``a``.

    Returns
    -------
    np.ndarray
        New array of shape ``shape`` with ``a`` centred inside zero padding.

    Raises
    ------
    ValueError
        If ``shape`` is smaller than ``a.shape`` along either axis.
    """
    target_rows, target_cols = shape
    rows, cols = a.shape

    row_pad = target_rows - rows
    col_pad = target_cols - cols
    if row_pad < 0 or col_pad < 0:
        raise ValueError(
            f"cannot pad array of shape {a.shape} to smaller shape {tuple(shape)}"
        )

    return np.pad(
        a,
        (
            (row_pad // 2, row_pad // 2 + row_pad % 2),
            (col_pad // 2, col_pad // 2 + col_pad % 2),
        ),
        mode="constant",
    )


def apply_padding(data_df, N, T_max, rng=None):
    """Sample ``N`` trajectories from ``data_df`` and pack them into padded arrays.

    Parameters
    ----------
    data_df : pd.DataFrame
        Must contain the columns ``traj_idx``, ``frame``, ``x``, ``y``,
        ``alpha``, ``D`` and ``state``.
    N : int
        Number of trajectories to sample. Sampling is without replacement,
        unless there are fewer than ``N`` distinct ``traj_idx`` values.
    T_max : int
        Number of time steps kept per trajectory. Longer trajectories are
        truncated; shorter ones are zero-padded at the end.
    rng : np.random.Generator | np.random.RandomState, optional
        Source of randomness for the sampling. Defaults to the global
        ``np.random`` state.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``final_data`` of shape ``(N, T_max, 3)`` with columns
        (frame, x, y), and ``final_label`` of shape ``(N, T_max, 3)`` with
        columns (alpha, D, state + 1). Each trajectory is re-origined so that
        its first frame is 1 and its first position is (0, 0).
    """
    final_data = np.zeros((N, T_max, 3))
    final_label = np.zeros((N, T_max, 3))

    traj_ids = data_df["traj_idx"].unique()
    sampler = np.random if rng is None else rng
    # Sample with replacement only when there are not enough distinct trajectories.
    selected_ids = sampler.choice(traj_ids, size=N, replace=len(traj_ids) < N)

    # Group once instead of re-scanning the whole frame for every selected id.
    trajectories = dict(tuple(data_df.groupby("traj_idx", sort=False)))

    for row, traj_id in enumerate(selected_ids):
        traj = trajectories[traj_id]

        # copy=True: the arrays are modified in place below, and to_numpy()
        # may return a read-only view under pandas copy-on-write.
        data = traj[["frame", "x", "y"]].to_numpy(copy=True)
        data[:, 0] = data[:, 0] - data[0, 0] + 1  # first frame becomes 1
        data[:, 1:] = data[:, 1:] - data[0, 1:]  # first position becomes (0, 0)

        label = traj[["alpha", "D", "state"]].to_numpy(copy=True)
        # Shift states by one; presumably so 0 can mark padded steps
        # (assumes states are >= 0 — TODO confirm).
        label[:, 2] = label[:, 2] + 1

        # Truncate long trajectories; shorter ones keep zero padding at the end.
        length = min(len(data), T_max)
        final_data[row, :length] = data[:length]
        final_label[row, :length] = label[:length]

    return final_data, final_label


def normalize_df(data):
    """Normalize the ``x`` and ``y`` columns of a trajectory DataFrame in place.

    Each coordinate is centred on its global mean and divided by
    ``sqrt(std(displacements))``. The displacements are frame-to-frame
    differences computed separately within each ``traj_idx`` group.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain ``traj_idx``, ``x`` and ``y`` columns; modified in place.

    Returns
    -------
    None
    """
    steps_x = []
    steps_y = []
    for _, group in data.groupby("traj_idx"):
        steps_x.append(np.diff(group["x"].to_numpy()))
        steps_y.append(np.diff(group["y"].to_numpy()))

    # Concatenate once (the previous list + list growth was quadratic).
    all_steps_x = np.concatenate(steps_x) if steps_x else np.array([])
    all_steps_y = np.concatenate(steps_y) if steps_y else np.array([])

    # NOTE(review): this scale is sqrt(std), not the variance or the std that
    # the old variable names suggested. Confirm this is the intended choice.
    scale_x = np.sqrt(np.std(all_steps_x))
    scale_y = np.sqrt(np.std(all_steps_y))

    # Replace whole columns: setting float values through .loc[:, col] into an
    # int column is deprecated in pandas.
    data["x"] = (data["x"] - data["x"].mean()) / scale_x
    data["y"] = (data["y"] - data["y"].mean()) / scale_y


def normalize_np(data):
    """Normalize channels 1 and 2 of a trajectory array in place.

    ``data`` is indexed as ``data[n, t, c]``. Channels 1 and 2 are treated as
    x and y; channel 0 is left untouched. Each is centred on its global mean
    and divided by ``sqrt(std)`` of its own step-to-step differences along
    axis 1.

    Parameters
    ----------
    data : np.ndarray
        Array with at least 3 channels on the last axis; modified in place.

    Returns
    -------
    np.ndarray
        The same ``data`` object.
    """
    # Differences along time, flattened trajectory by trajectory (row-major,
    # i.e. the same order as looping over n and then t).
    steps_x = np.diff(data[:, :, 1], axis=1).ravel()
    steps_y = np.diff(data[:, :, 2], axis=1).ravel()

    # NOTE(review): the scale is sqrt(std), not the variance or the std.
    scale_x = np.sqrt(np.std(steps_x))
    scale_y = np.sqrt(np.std(steps_y))

    # NOTE(review): for zero-padded input, the padding (and the jump from the
    # last real point back to 0) is included in these statistics and means,
    # and padded entries stop being 0 afterwards. Consider masking it out.
    data[:, :, 1] = (data[:, :, 1] - np.mean(data[:, :, 1])) / scale_x
    # y is scaled by its own displacement statistic (it previously used x's).
    data[:, :, 2] = (data[:, :, 2] - np.mean(data[:, :, 2])) / scale_y

    return data


def list_directory_tree_with_pathlib(starting_directory):
    """Recursively collect every ``*.csv`` file below ``starting_directory``.

    Returns a list of ``Path`` objects in the order ``Path.rglob`` yields them.
    The entries are files, not folders.
    """
    return list(Path(starting_directory).rglob("*.csv"))


# Define a custom dataset class for all data
@dataclass
class Dataset_all_data(Dataset):
    # Initialize filenames and transform flag
    # Pad value should be a tuple such as (N, Tmax)
    filenames: list
    transform: bool = False
    pad: None | tuple = None
    noise: bool = False

    def __len__(self):
        # Return the number of files
        return len(self.filenames)

    def __getitem__(self, idx):
        # Read csv file and extract data and label
        df = pd.read_csv(self.filenames[idx])

        if self.pad is None:
            data = df[["traj_idx", "frame", "x", "y"]]
            label = np.asarray(df[["alpha", "D"]])
            label_2 = np.asarray(df["state"])

        else:
            if len(self.pad) != 2:
                raise ValueError("pad value should be set as (N, T_max)")
            data, label = apply_padding(df, *self.pad)
            data = data[:, :, 1:]  ## Removing the frame column
            label_2 = label[:, :, -1]
            label_2[label_2[:, :] > 0] = label_2[label_2[:, :] > 0]
            label = label[:, :, :-1]

        # Normalize data if transform flag is True
        if self.transform:
            if self.pad is None:
                normalize_df(data)
                data = np.asarray(data)
            else:
                data = normalize_np(data)

        if self.noise:
            data = add_noise(data)

        # Normalize D between 0 and 1

        # label[:,:,1][label[:,:,1] != 0] = np.log(label[:,:,1][label[:,:,1] != 0]) #- np.log(1e-6)) #/   (np.log(1e12) - np.log(1e-6))
        # label = label[:,:,1]
        label_regression = np.zeros((label.shape[0], 2))

        # print(np.unique(label_2))

        for i in range(label.shape[0]):
            K = np.unique(label[i, :, 1][label[i, :, 1] != 0])
            if len(K) == 2:
                label_regression[i, :] = K

                if label[i, 0, 1] != label_regression[i, 0]:
                    label_regression[i, :] = label_regression[i, ::-1]

            elif len(K) == 1:
                states = label_2[i, :]
                if 1 in states:
                    # print(np.unique(states))
                    if states[0] == 1:
                        label_regression[i, :] = [0, K[0]]
                    else:
                        label_regression[i, :] = [K[0], 0]

                    # print(label_regression[i,:])

                else:
                    label_regression[i, :] = [K[0], K[0]]

            else:
                if np.unique(label[i, :, 1]) == 0:
                    label_regression[i, :] = [0, 0]
                else:

                    # print(np.unique(label[i,:,1]))

                    # print(Ds)
                    raise Exception("more than 2 diffusions")

        label_segmentation = np.zeros((label_2.shape[0], label_2.shape[1]))

        for i in range(label.shape[0]):
            if label_regression[i, 0] == label_regression[i, 1]:
                position = label[i, :, 1] == label_regression[i, 0]
                label_segmentation[i, position] = 1
            else:

                position_1 = label[i, :, 1] == label_regression[i, 0]
                position_2 = label[i, :, 1] == label_regression[i, 1]

                label_segmentation[i, position_1] = 1
                label_segmentation[i, position_2] = 2

        return torch.from_numpy(data.astype(np.float32)), torch.from_numpy(
            label_segmentation.astype(np.float32)
        )
        # torch.from_numpy(label_2.astype(np.float32)),


def add_noise(data):
    """Apply random multiplicative noise to the non-zero entries of ``data``.

    One amplitude is drawn from {0.01, 0.1}. Every non-zero entry ``v``
    becomes ``v + v * eps`` with ``eps ~ N(0, amplitude)``, so zero entries
    (including zero padding) stay exactly zero. Works for arrays of any
    shape; the mask and the noise are indexed consistently.

    Parameters
    ----------
    data : np.ndarray
        Float array; modified in place.

    Returns
    -------
    np.ndarray
        The same ``data`` object.
    """
    noise_amplitude = np.random.choice([0.01, 0.1])
    noise = np.random.normal(0, noise_amplitude, data.shape)
    # NOTE(review): the noise is relative (multiplicative), not additive.
    # Confirm this is the intended noise model.
    nonzero = data != 0
    data[nonzero] = data[nonzero] + data[nonzero] * noise[nonzero]
    return data

In [81]:
# Collect every CSV under the dataset directory and shuffle it reproducibly.
# Sorting first removes the filesystem-dependent rglob order. Seeding numpy
# and torch also fixes apply_padding's trajectory sampling, DataLoader
# shuffling and model initialisation in the later cells.
# NOTE(review): absolute, user-specific path; consider making it configurable.
DATA_DIR = Path("/home/m.lavaud/Documents/dataset")
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

all_data_set = sorted(list_directory_tree_with_pathlib(DATA_DIR))
np.random.shuffle(all_data_set)

In [82]:
N_TRAIN_FILES = 2000  # number of shuffled CSV files used for training
PAD_SHAPE = (20, 200)  # (N trajectories sampled per file, T_max time steps)
training_dataset = Dataset_all_data(
    all_data_set[:N_TRAIN_FILES],
    transform=False,
    pad=PAD_SHAPE,
)

In [83]:
# Iterator over the dataset for quick manual inspection. Dataset_all_data
# defines __getitem__ and __len__, and each step yields one file's
# (positions, segmentation targets) tuple.
test = iter(training_dataset)

In [84]:
# First item is (positions, segmentation targets); keep the targets as `a`.
sample_positions, a = next(test)

20.0

In [91]:
BATCH_SIZE = 10  # files per batch; each file contributes pad[0] trajectories
dataloader = DataLoader(
    training_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

In [92]:
from einops import rearrange

In [93]:
from mamba_ssm import Mamba

  from .autonotebook import tqdm as notebook_tqdm


In [103]:
class segmentation_model(nn.Module):
    """Per-time-step 3-class classifier: Mamba block -> dropout -> linear head.

    Parameters
    ----------
    d_model, d_state, d_conv, expand :
        Passed unchanged to ``Mamba``. ``d_model`` is also the input size of
        the linear head.
    dropout : float
        Dropout probability applied to the Mamba output.
    device : str
        Device all submodules are moved to; also stored as ``self.device``.
    """

    def __init__(self, d_model, d_state, d_conv, expand, dropout=0.2, device="cuda"):
        super().__init__()
        self.device = device

        self.mamba = Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(in_features=d_model, out_features=3)
        # Not used in forward(); kept so the module's attributes are unchanged.
        self.softplus = nn.Softplus()

        # Move every registered submodule to the target device in one call.
        self.to(device)

    def forward(self, input):
        """Return raw logits with last dimension 3.

        No activation is applied here; the cross-entropy loss applies it.
        """
        features = self.dropout(self.mamba(input))
        return self.fc(features)

In [117]:
# d_model=2 matches the two input channels (x, y) produced by the dataset.
model = segmentation_model(
    d_model=2,
    d_state=1,
    d_conv=4,
    expand=4,
)

In [118]:
LEARNING_RATE = 1e-3
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [119]:
max_epoch = 10
N_CLASSES = 3  # 0 = padding (ignored by the loss), 1 and 2 = segment classes
device = model.device  # the device segmentation_model moved its layers to


def batch_class_weights(targets, n_classes=N_CLASSES):
    """Balanced per-batch class weights; class 0 (padding) always gets 0.

    For each class c >= 1 that is present: weight = n_non_padding / ((n_classes - 1) * count_c).
    Classes absent from the batch get weight 0. bincount keeps every class in
    its own slot, whereas the previous unique()[1][1:] slicing shifted or
    broadcast counts whenever a class was missing from the batch.
    """
    counts = torch.bincount(targets.reshape(-1), minlength=n_classes)[1:n_classes].float()
    weights = torch.where(
        counts > 0,
        counts.sum() / ((n_classes - 1) * counts.clamp(min=1)),
        torch.zeros_like(counts),
    )
    return torch.cat([torch.zeros(1, device=counts.device), weights])


total_running_loss = []  # mean training loss per epoch
for epoch in range(max_epoch):
    epoch_losses = []
    model.train()
    with tqdm(dataloader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch}")

        for inputs, classification_targets in tepoch:
            # Merge the file and trajectory dimensions so every trajectory is one sequence.
            inputs = torch.flatten(inputs, start_dim=0, end_dim=1).to(device)
            classification_targets = (
                torch.flatten(classification_targets, start_dim=0, end_dim=1)
                .long()
                .to(device)
            )

            weight = batch_class_weights(classification_targets)
            if weight.sum() == 0:
                # Only padding in this batch: the weighted mean loss would be 0/0 = NaN.
                continue

            optimizer.zero_grad()
            classification_output = model(inputs)

            # Weights change per batch, so the criterion is rebuilt each step.
            classification_criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=0)
            classification_loss = classification_criterion(
                classification_output.reshape(-1, N_CLASSES),
                classification_targets.reshape(-1),
            )
            classification_loss.backward()
            optimizer.step()

            loss_value = classification_loss.item()
            tepoch.set_postfix(
                loss=loss_value,
                s1=torch.sum(classification_targets == 1).item(),
                s2=torch.sum(classification_targets == 2).item(),
                w1=weight[1].item(),
                w2=weight[2].item(),
            )
            epoch_losses.append(loss_value)

    total_running_loss.append(np.mean(epoch_losses))

Epoch 0: 100%|██████████| 200/200 [00:38<00:00,  5.19batch/s, loss=1.22, s1=31622, s2=1194, w1=0.519, w2=13.7] 
Epoch 1: 100%|██████████| 200/200 [00:43<00:00,  4.58batch/s, loss=1.01, s1=28839, s2=2294, w1=0.54, w2=6.79]  
Epoch 2: 100%|██████████| 200/200 [00:43<00:00,  4.59batch/s, loss=0.894, s1=35511, s2=621, w1=0.509, w2=29.1] 
Epoch 3: 100%|██████████| 200/200 [00:41<00:00,  4.78batch/s, loss=0.836, s1=32666, s2=401, w1=0.506, w2=41.2] 
Epoch 4: 100%|██████████| 200/200 [00:41<00:00,  4.81batch/s, loss=0.903, s1=34807, s2=804, w1=0.512, w2=22.1] 
Epoch 5: 100%|██████████| 200/200 [00:43<00:00,  4.58batch/s, loss=0.755, s1=32418, s2=1782, w1=0.527, w2=9.6] 
Epoch 6: 100%|██████████| 200/200 [00:42<00:00,  4.65batch/s, loss=0.756, s1=33660, s2=266, w1=0.504, w2=63.8] 
Epoch 7: 100%|██████████| 200/200 [00:42<00:00,  4.68batch/s, loss=0.716, s1=33366, s2=814, w1=0.512, w2=21]   
Epoch 8: 100%|██████████| 200/200 [00:41<00:00,  4.84batch/s, loss=0.723, s1=31492, s2=772, w1=0.512, w2

In [120]:
# Mean training loss per epoch; the training loop appends one entry per epoch.
total_running_loss

[3.358871024250984,
 1.4566502141952515,
 1.0073384791612625,
 0.9063091453909874,
 0.8509556940197944,
 0.8101608544588089,
 0.7790414336323738,
 0.7641315221786499,
 0.7396654093265533,
 0.731806293129921]

In [116]:
# NOTE(review): this cell's execution count (116) is lower than the training
# cell's (119). Its output therefore comes from an earlier training run, not
# the current one. Consider deleting this cell before sharing.
total_running_loss

[67.42232237935066,
 51.4404145938158,
 47.10895276725292,
 45.17216860175133,
 41.39366302847862,
 45.67685935378075,
 46.89691228687763,
 34.27214115440845,
 25.95659221470356,
 37.28627359747887]