# Ex-vivo Raman Spectroscopy and AI-based Classification of Soft Tissue Sarcomas 

This notebook demonstrates training and evaluating a ResNet-based model that classifies Raman spectra of tissue samples into classes including Normal, Benign, and Malignant.

## Setup and Data Preparation

The following cells import necessary libraries, load metadata, and prepare the dataset for training.

In [None]:
import os
import pandas as pd
import numpy as np

from config import (
    METADATA_DIR,
    DATASET_DIR,
    TRAIN_DATA_DIR,
    TEST_DATA_DIR,
    MODELS_DIR,
    metadata
)
# Automatically reload edited project modules (e.g. config, utils) before each cell runs
%reload_ext autoreload
%autoreload 2

## Load and Examine Metadata

The metadata describes each sample: the patient ID (`Idx`, P1 to P7), the tissue type (`ID_Type`), the clinical diagnosis (`Final Diagnosis`), and the tissue class (`Tissue Classification`).

In [2]:
metadata

| Unnamed: 0 | Idx | Age | Sex | Location | Final Diagnosis | ID_Type | Tissue Classification |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | P1 | 86 | F | left thigh | Skin layers | SKN | Normal |
| 1 | P1 | 86 | F | left thigh | High-grade pleomorphic liposarcoma | PLS | Malignant |
| 2 | P1 | 86 | F | left thigh | Muscle | MSC | Normal |
| 3 | P2 | 70 | F | right thigh | Fat | FAT | Normal |
| 4 | P2 | 70 | F | right thigh | Muscle | MSC | Normal |
| 5 | P3 | 66 | F | right thigh | High-grade pleomorphic liposarcoma | PLS | Malignant |
| 6 | P3 | 66 | F | right thigh | Muscle | MSC | Normal |
| 7 | P3 | 66 | F | right thigh | Skin layers | SKN | Normal |
| 8 | P3 | 66 | F | right thigh | Fat | FAT | Normal |
| 9 | P4 | 37 | M | right knee | Skin layers | SKN | Normal |
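
As a quick way to examine this kind of table, the sketch below rebuilds only the ten rows shown above in a small DataFrame (named `preview` here for illustration; it is not the full `metadata` object) and counts samples per tissue class and tissue types per patient.

```python
import pandas as pd

# Rows copied from the metadata preview above (first 10 samples only).
rows = [
    ("P1", 86, "F", "left thigh", "Skin layers", "SKN", "Normal"),
    ("P1", 86, "F", "left thigh", "High-grade pleomorphic liposarcoma", "PLS", "Malignant"),
    ("P1", 86, "F", "left thigh", "Muscle", "MSC", "Normal"),
    ("P2", 70, "F", "right thigh", "Fat", "FAT", "Normal"),
    ("P2", 70, "F", "right thigh", "Muscle", "MSC", "Normal"),
    ("P3", 66, "F", "right thigh", "High-grade pleomorphic liposarcoma", "PLS", "Malignant"),
    ("P3", 66, "F", "right thigh", "Muscle", "MSC", "Normal"),
    ("P3", 66, "F", "right thigh", "Skin layers", "SKN", "Normal"),
    ("P3", 66, "F", "right thigh", "Fat", "FAT", "Normal"),
    ("P4", 37, "M", "right knee", "Skin layers", "SKN", "Normal"),
]
columns = ["Idx", "Age", "Sex", "Location", "Final Diagnosis", "ID_Type", "Tissue Classification"]
preview = pd.DataFrame(rows, columns=columns)

# Number of samples per tissue class in the preview.
class_counts = preview["Tissue Classification"].value_counts()

# Tissue type codes collected from each patient, in row order.
types_per_patient = preview.groupby("Idx")["ID_Type"].agg(list)

print(class_counts)
print(types_per_patient)
```

The same two calls can be applied to the full `metadata` DataFrame, assuming it uses the column names shown in the preview.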


## Data Structure Creation

The `data_utils` module provides functions for building the data structure, including splitting the dataset into training and test sets. Download the dataset from the provided link and extract it into the `./data/dataset` directory. `create_data_structure` then reads the extracted dataset and writes the train and test sets, using the following directory arguments:

- `data_dir=DATASET_DIR` (directory of the downloaded dataset)
- `train_data_dir=TRAIN_DATA_DIR` (directory for the train dataset)
- `test_data_dir=TEST_DATA_DIR` (directory for the test dataset)

The call in the next cell also passes `test_size=0.2`, `random_state=42`, and `clear_dirs=True`. It is commented out; uncomment it to create the split.
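
The implementation of `create_data_structure` is not shown in this notebook. As a rough illustration of what a seeded 80/20 split into train and test directories can look like, here is a minimal sketch using only the standard library. The function name `split_into_train_test`, the one-subfolder-per-class layout, and the file names are all assumptions made for the example; the real function may organize the data differently.

```python
import random
import shutil
import tempfile
from pathlib import Path


def split_into_train_test(data_dir, train_dir, test_dir, test_size=0.2, random_state=42):
    """Copy files from data_dir/<class>/ into train_dir/<class>/ and test_dir/<class>/.

    Hypothetical stand-in for create_data_structure; the real function may differ.
    """
    rng = random.Random(random_state)  # seeded so the split is reproducible
    for class_dir in sorted(p for p in Path(data_dir).iterdir() if p.is_dir()):
        files = sorted(class_dir.iterdir())
        rng.shuffle(files)
        n_test = round(len(files) * test_size)
        for i, src in enumerate(files):
            # The first n_test shuffled files go to the test set, the rest to train.
            target_root = Path(test_dir) if i < n_test else Path(train_dir)
            dest = target_root / class_dir.name
            dest.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src, dest / src.name)


# Demo on a throwaway directory tree with made-up file names.
root = Path(tempfile.mkdtemp())
for cls in ("Normal", "Malignant"):
    (root / "dataset" / cls).mkdir(parents=True)
    for i in range(10):
        (root / "dataset" / cls / f"spectrum_{i}.csv").write_text("0,0\n")

split_into_train_test(root / "dataset", root / "train", root / "test")
n_train = len(list((root / "train").rglob("*.csv")))
n_test = len(list((root / "test").rglob("*.csv")))
print(n_train, n_test)
```

Splitting within each class folder keeps the class proportions roughly equal between the two sets.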


In [None]:
import utils.data_utils as du

# Uncomment to build the train/test split from the downloaded dataset
# du.create_data_structure(metadata=metadata,
#                          data_dir=DATASET_DIR,
#                          train_data_dir=TRAIN_DATA_DIR,
#                          test_data_dir=TEST_DATA_DIR,
#                          test_size=0.2,
#                          random_state=42, 
#                          clear_dirs=True)

## Dataset Loading and Model Initialization

Here, we import the ResNet model architecture and set up PyTorch DataLoaders for training and testing. The `RamanDataset` class loads spectral data from the training and testing directories.
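
`RamanDataset` is defined in the project code and is not shown here. The sketch below is a hypothetical in-memory stand-in (`ToySpectraDataset`) that exposes the same two attributes this notebook relies on, `class_to_idx` and `class_counts`, and shows how a `DataLoader` batches it. The random data, the class names, and the `(1, n_points)` item shape are assumptions for illustration only.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToySpectraDataset(Dataset):
    """Hypothetical stand-in for RamanDataset, backed by in-memory tensors."""

    def __init__(self, spectra, labels, class_names):
        self.spectra = torch.as_tensor(spectra, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.long)
        self.class_to_idx = {name: i for i, name in enumerate(class_names)}
        # Per-class sample counts, keyed by class name.
        self.class_counts = {
            name: int((self.labels == i).sum()) for name, i in self.class_to_idx.items()
        }

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # Add a channel dimension so each item has shape (1, n_points),
        # the layout a 1D convolutional network expects.
        return self.spectra[index].unsqueeze(0), self.labels[index]


# 12 random "spectra" of 500 points each, spread over three classes.
torch.manual_seed(0)
toy_data = ToySpectraDataset(
    spectra=torch.rand(12, 500),
    labels=[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2],
    class_names=["Benign", "Malignant", "Normal"],
)
toy_loader = DataLoader(toy_data, batch_size=4, shuffle=True)

batch_spectra, batch_labels = next(iter(toy_loader))
print(batch_spectra.shape, batch_labels.shape)
print(toy_data.class_counts)
```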

In [None]:
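import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# The star import is expected to provide RamanDataset, RamanResNet, and train_model
from models.resnet import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dir = TRAIN_DATA_DIR      # this uses files in train folder
test_dir = TEST_DATA_DIR        # this uses files in test folder

# Shuffle only the training data; keep the test order fixed
train_data = RamanDataset(train_dir)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=8, pin_memory=True)

test_data = RamanDataset(test_dir)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False, num_workers=8, pin_memory=True)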


In [None]:
# Print the per-class spectra counts for the train and test sets, and the overall total
print(f"Train spectra per class: {train_data.class_counts}")
print(f"Test spectra per class: {test_data.class_counts}")

print("Class to index mapping:")
print(train_data.class_to_idx)

print(f"Total spectra: {len(train_data) + len(test_data)}")

## Class Weight Calculation

To address class imbalance (some tissue types may have more samples than others), we calculate class weights that are inversely proportional to each class's frequency in the training set, then normalize them so the largest weight is 1. These weights are passed to the loss function so that underrepresented classes count more during training.

In [None]:

# Inverse-frequency class weights: rarer classes get larger weights
total_samples = len(train_data)
class_weights = [total_samples / count for count in train_data.class_counts.values()]
# Normalize so the largest weight (rarest class) equals 1
max_weight = max(class_weights)
class_weights = [weight / max_weight for weight in class_weights]
# class_weights = [0.5 if weight < 0.5 else 1 for weight in class_weights]
# Note: this assumes class_counts is ordered like train_data.class_to_idx
class_weights = torch.tensor(class_weights, dtype=torch.float32, device=device)

class_weights
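
To make the weighting concrete, here is the same arithmetic on made-up counts (the numbers and the name `toy_counts` are illustrative, not the dataset's real counts). After normalization each weight equals the smallest class count divided by that class's count.

```python
# Toy per-class counts; the real values come from train_data.class_counts.
toy_counts = {"Benign": 2000, "Malignant": 1000, "Normal": 4000}

total = sum(toy_counts.values())

# Inverse-frequency weights: rarer classes get larger values.
raw_weights = {name: total / count for name, count in toy_counts.items()}

# Divide by the largest weight so the rarest class has weight 1.0.
largest = max(raw_weights.values())
toy_weights = {name: w / largest for name, w in raw_weights.items()}

print(toy_weights)  # → {'Benign': 0.5, 'Malignant': 1.0, 'Normal': 0.25}
```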

## Model Training

This cell performs the full model training process:
1. Initializes the ResNet model architecture
2. Sets up the loss function (CrossEntropyLoss) with the calculated class weights
3. Configures the optimizer (SGD with momentum) and learning rate scheduler (CyclicLR)
4. Trains the model for 50 epochs, saving the best checkpoint based on validation accuracy (the test loader is passed to `train_model` as the validation set)
5. Plots the training/validation loss and accuracy curves to visualize the training progress
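
The `triangular` cyclic policy in step 3 moves the learning rate linearly from `base_lr` up to `max_lr` over `step_size_up` scheduler steps and then back down over the same number of steps. The sketch below computes that schedule in plain Python with the values used in this notebook. The helper name `triangular_lr` is made up for the example, and it assumes the down phase is as long as the up phase, which is PyTorch's default when `step_size_down` is not given.

```python
import math


def triangular_lr(iteration, base_lr=0.001, max_lr=0.01, step_size_up=8 * 1792):
    """Learning rate of the 'triangular' cyclical policy at a given scheduler step."""
    # Index of the current cycle (1-based); one cycle is up plus down.
    cycle = math.floor(1 + iteration / (2 * step_size_up))
    # Distance from the peak of the current cycle, scaled to [0, 1].
    x = abs(iteration / step_size_up - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)


half_cycle = 8 * 1792  # 14336 scheduler steps from base_lr to max_lr
for it in (0, half_cycle // 2, half_cycle, 2 * half_cycle):
    print(it, round(triangular_lr(it), 6))
```

How many epochs one half cycle covers depends on how often `train_model` calls `scheduler.step()`, which is not shown in this notebook.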

In [None]:
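import time

num_classes = len(train_data.class_to_idx)
# Keep the model on the same device as class_weights
model = RamanResNet(num_classes).to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
# CyclicLR sets the learning rate to base_lr when it is created, so lr here is only a starting value
optimizer = optim.SGD(model.parameters(), lr=0.002, momentum=0.9)

# Name the checkpoint after the current date and time
timestr = time.strftime("%m%d%Y-%H%M%S")
model_path = os.path.join(MODELS_DIR, 'checkpoints', f"raman_resnet_{timestr}.pth")
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# step_size_up is the number of scheduler steps taken to go from base_lr to max_lr
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
                                              base_lr=0.001,
                                              max_lr=0.01,
                                              step_size_up=8 * 1792,
                                              mode='triangular')

# test_loader serves as the validation set during training
train_losses, val_losses, train_accuracies, val_accuracies = train_model(
    model, train_loader, test_loader,
    criterion, optimizer, num_epochs=50,
    model_path=model_path, scheduler=scheduler)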

# Plot the results
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss over Epochs')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Accuracy over Epochs')
plt.legend()

plt.tight_layout()
plt.show()
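
`train_model` comes from the project code and its body is not shown in this notebook. For readers who want to see the overall pattern, the sketch below runs a weighted-loss training loop with a cyclic learning rate and best-checkpoint saving on synthetic data. Everything in it is an assumption made for illustration: the tiny network is not `RamanResNet`, the data are random, the class weights are made up, and stepping the scheduler once per batch is a guess about how `train_model` behaves.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-ins: 64 training and 32 validation "spectra" of 100 points, 3 classes.
train_set = TensorDataset(torch.rand(64, 1, 100), torch.randint(0, 3, (64,)))
val_set = TensorDataset(torch.rand(32, 1, 100), torch.randint(0, 3, (32,)))
train_dl = DataLoader(train_set, batch_size=16, shuffle=True)
val_dl = DataLoader(val_set, batch_size=16)

# Tiny 1D CNN used only for illustration; it is not RamanResNet.
net = nn.Sequential(
    nn.Conv1d(1, 4, kernel_size=5, padding=2), nn.ReLU(),
    nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(4, 3),
)
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([0.5, 1.0, 0.25]))
opt = optim.SGD(net.parameters(), lr=0.002, momentum=0.9)
sched = optim.lr_scheduler.CyclicLR(opt, base_lr=0.001, max_lr=0.01,
                                    step_size_up=8, mode="triangular")

ckpt_path = os.path.join(tempfile.mkdtemp(), "best.pth")
best_acc = -1.0
history = []

for epoch in range(3):
    net.train()
    for x, y in train_dl:
        opt.zero_grad()
        loss = loss_fn(net(x), y)
        loss.backward()
        opt.step()
        sched.step()  # assumption: the scheduler advances once per batch

    # Validation accuracy in percent for this epoch.
    net.eval()
    correct = 0
    with torch.no_grad():
        for x, y in val_dl:
            correct += (net(x).argmax(dim=1) == y).sum().item()
    acc = 100.0 * correct / len(val_set)
    history.append(acc)

    # Keep only the weights from the best validation epoch so far.
    if acc > best_acc:
        best_acc = acc
        torch.save(net.state_dict(), ckpt_path)

# Reload the best weights, e.g. for a separate evaluation step.
net.load_state_dict(torch.load(ckpt_path))
print(history, best_acc)
```

Saving only when validation accuracy improves means the file on disk always holds the best epoch, not the last one.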

## Conclusion

The model has been trained on Raman spectroscopy data to classify different tissue types. The saved model checkpoint can be used for evaluation on test data, as demonstrated in the `eval.ipynb` notebook.

If you find this helpful, please consider citing our paper:

Boroji, M., Danesh, V., Barrera, D., Lee, E., Arauz, P., Farrell, R., Boyce, B., Khan, F., and Kao, I. "Ex-Vivo Raman Spectroscopy and AI-Based Classification of Soft Tissue Sarcomas", PLOS ONE, Public Library of Science, 2025.
