<a href="https://colab.research.google.com/github/AnuragSharma5893/video-to-cryo-ET-data/blob/main/Transfer_Learning_from_video_to_cryo_ET_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Cell 1: Installations
# mrcfile is used later to open the .mrc volumes; transformers supplies the
# VideoMAE model and image processor. `-q` keeps pip's output quiet.
!pip install -q mrcfile transformers huggingface_hub

In [None]:
# Full Analysis Pipeline Code

import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import mrcfile
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from functools import partial
from PIL import Image

# Suppress warnings for cleaner output
import warnings
warnings.filterwarnings("ignore")

# --- 3D-ResNet34 Model Definition (from kenshohara/video-classification-3d-cnn-pytorch) ---
def conv3x3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3x3 ``nn.Conv3d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)

def conv1x1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1x1) ``nn.Conv3d`` with the given stride."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)

class BasicBlock(nn.Module):
    """Residual block made of two 3x3x3 convolutions, each followed by BatchNorm3d.

    The block input (optionally passed through ``downsample``) is added to the
    output of the second batch-norm, and a ReLU is applied to the sum.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        # First conv carries the stride; the second always uses stride 1.
        self.conv1 = conv3x3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        # Optional callable applied to the input before the residual addition.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)

class ResNet(nn.Module):
    """3-D ResNet: stem convolution, four residual stages, global pooling, linear head.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(in_planes, planes, stride, downsample)``.
        layers: Number of blocks in each of the four stages.
        block_inplanes: Base channel width of each of the four stages.
        n_input_channels: Channels of the input volume.
        conv1_t_size: Stem kernel size along the first spatial (depth/temporal) axis.
        conv1_t_stride: Stem stride along that same axis.
        no_max_pool: When True, ``forward`` skips the max-pool after the stem.
        shortcut_type: ``'A'`` selects the parameter-free zero-padding shortcut;
            any other value selects a 1x1x1 convolution followed by BatchNorm3d.
        widen_factor: Multiplier applied to every entry of ``block_inplanes``.
        n_classes: Accepted for signature compatibility only. The head below is
            fixed at 1000 outputs regardless of this value.
    """

    def __init__(self, block, layers, block_inplanes, n_input_channels=3, conv1_t_size=7,
                 conv1_t_stride=1, no_max_pool=False, shortcut_type='B', widen_factor=1.0, n_classes=400):
        super().__init__()
        block_inplanes = [int(x * widen_factor) for x in block_inplanes]
        # Running channel count; _make_layer updates it as stages are built.
        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool
        self.conv1 = nn.Conv3d(n_input_channels, self.in_planes, kernel_size=(conv1_t_size, 7, 7),
                               stride=(conv1_t_stride, 2, 2), padding=(conv1_t_size // 2, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(self.in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], shortcut_type)
        self.layer2 = self._make_layer(block, block_inplanes[1], layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(block, block_inplanes[2], layers[2], shortcut_type, stride=2)
        self.layer4 = self._make_layer(block, block_inplanes[3], layers[3], shortcut_type, stride=2)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # NOTE: the head is hard-wired to 1000 outputs; `n_classes` is not consulted.
        self.fc = nn.Linear(block_inplanes[3] * block.expansion, 1000)
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _downsample_basic_block(self, x, planes, stride):
        """Parameter-free shortcut: subsample ``x`` and zero-pad its channels to ``planes``.

        The padding is allocated with the same dtype and device as the pooled
        tensor, and the pooled tensor stays attached to the autograd graph so
        gradients can flow through the shortcut.
        """
        out = F.avg_pool3d(x, kernel_size=1, stride=stride)
        zero_pads = out.new_zeros(out.size(0), planes - out.size(1),
                                  out.size(2), out.size(3), out.size(4))
        return torch.cat([out, zero_pads], dim=1)

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one stage of ``blocks`` residual blocks; only the first block is strided."""
        downsample = None
        # A shortcut transform is needed whenever the first block changes the
        # resolution or the channel count.
        if stride != 1 or self.in_planes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(self._downsample_basic_block,
                                     planes=planes * block.expansion, stride=stride)
            else:
                downsample = nn.Sequential(
                    conv1x1x1(self.in_planes, planes * block.expansion, stride),
                    nn.BatchNorm3d(planes * block.expansion))
        layers = [block(self.in_planes, planes, stride, downsample)]
        self.in_planes = planes * block.expansion
        layers.extend(block(self.in_planes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on a 5-D input and return the output of the linear head."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if not self.no_max_pool:
            x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

def generate_resnet34(**kwargs):
    """Construct a ``ResNet`` of ``BasicBlock``s with the 34-layer stage layout.

    Every keyword argument is passed straight through to ``ResNet``.
    """
    blocks_per_stage = [3, 4, 6, 3]
    stage_widths = [64, 128, 256, 512]
    return ResNet(BasicBlock, blocks_per_stage, stage_widths, **kwargs)

class ResNetFeatureExtractor(nn.Module):
    """Wrap a ResNet so that it emits a 400-dimensional vector per input.

    All child modules of ``resnet_model`` except the last one are chained into
    ``features``; the model's own ``fc`` is reused, and a newly created
    ``Linear(1000, 400)`` layer is applied on top of it.
    """

    def __init__(self, resnet_model):
        super().__init__()
        child_modules = list(resnet_model.children())
        backbone = child_modules[:-1]  # everything but the final child
        self.features = nn.Sequential(*backbone)
        self.fc = resnet_model.fc
        self.fc_new = nn.Linear(1000, 400)  # freshly initialised projection

    def forward(self, x):
        """Return ``fc_new(fc(flatten(features(x))))``."""
        pooled = self.features(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc_new(self.fc(flat))

# --- VideoMAE imports ---
# transformers is treated as optional: if the import fails, both names are set
# to None so that later cells can check `VideoMAEModel is not None` before use.
try:
    from transformers import VideoMAEModel, AutoImageProcessor
except ImportError:
    print("Transformers library not found. Please run the installation cell first.")
    VideoMAEModel = None
    AutoImageProcessor = None

# --- Data Loading and Preprocessing ---
def preprocess_subtomogram_resnet(mrc_data, target_shape=(96, 96, 96)):
    """Turn one 3-D volume into a ``(1, 3, *target_shape)`` float32 tensor in [-1, 1].

    Steps: min-max scale to [0, 1], resize with linear interpolation, shift and
    scale to [-1, 1] (mean 0.5, std 0.5), then replicate the single channel
    three times and add a batch dimension.

    Args:
        mrc_data: 3-D array-like volume.
        target_shape: Output size along the three spatial axes.

    Returns:
        A contiguous ``torch.float32`` tensor of shape ``(1, 3, *target_shape)``.

    Raises:
        ValueError: If ``mrc_data`` is not 3-D or ``target_shape`` does not
            have exactly three entries.
    """
    data = np.asarray(mrc_data).astype(np.float32)
    if data.ndim != 3 or len(target_shape) != 3:
        raise ValueError(
            f"Expected a 3-D volume and a 3-element target_shape, "
            f"got data shape {data.shape} and target_shape {tuple(target_shape)}")

    min_val, max_val = data.min(), data.max()
    if max_val > min_val:
        data = (data - min_val) / (max_val - min_val)
    else:
        # A constant volume has no contrast; map it to 0 rather than letting
        # the raw intensity leak through the normalisation below.
        data = np.zeros_like(data)

    zoom_factors = [t / s for t, s in zip(target_shape, data.shape)]
    resized = zoom(data, zoom_factors, order=1)

    # Equivalent to Normalize(mean=0.5, std=0.5) applied to each channel.
    volume = (torch.from_numpy(resized) - 0.5) / 0.5

    # (D, H, W) -> (3, D, H, W) -> (1, 3, D, H, W); repeat() yields a contiguous copy.
    return volume.unsqueeze(0).repeat(3, 1, 1, 1).unsqueeze(0)


def preprocess_subtomogram_videomae(mrc_data, image_processor, num_frames=16):
    """Convert a volume into ``num_frames`` RGB frames and run the image processor.

    The volume is min-max scaled to 0..255 and cast to uint8, then
    ``num_frames`` slices are picked at evenly spaced positions along axis 0
    (positions repeat when the volume has fewer slices than ``num_frames``).

    Args:
        mrc_data: Array-like volume; axis 0 is used as the frame axis.
        image_processor: Callable invoked as
            ``image_processor(frames, return_tensors="pt")``.
        num_frames: Number of frames to sample.

    Returns:
        Whatever ``image_processor`` returns for the list of frames.
    """
    data = mrc_data.astype(np.float32)
    min_val, max_val = data.min(), data.max()
    if max_val > min_val:
        data = 255 * (data - min_val) / (max_val - min_val)
    else:
        # Constant volume: casting unscaled floats to uint8 would wrap around,
        # so emit black frames instead.
        data = np.zeros_like(data)
    data = data.astype(np.uint8)

    depth = data.shape[0]
    indices = np.linspace(0, depth - 1, num_frames, dtype=int)
    # fromarray() on a 2-D uint8 slice gives a single-channel image; the frames
    # are converted to RGB so they carry three channels, mirroring the channel
    # replication done in the ResNet preprocessing.
    frames = [Image.fromarray(data[i]).convert("RGB") for i in indices]
    return image_processor(frames, return_tensors="pt")

In [None]:
# Cell to download the ResNet weights
import os

import requests

# NOTE(review): the URL is the `resnet34-333f7ec4.pth` checkpoint on
# download.pytorch.org, while the local filename says "kinetics" - confirm
# this is the intended weight file.
url = 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'
file_path = '/resnet-34-kinetics-cpu.pth'
# Download into a side file first so an interrupted transfer never leaves a
# truncated checkpoint at `file_path` (later cells only check that it exists).
tmp_path = file_path + '.part'

try:
    print(f"Downloading {url} to {file_path}...")
    # (connect timeout, read timeout) in seconds; without one a stalled
    # connection would block the cell indefinitely.
    with requests.get(url, stream=True, timeout=(10, 60)) as r:
        r.raise_for_status()
        with open(tmp_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    os.replace(tmp_path, file_path)
    print("Download complete.")
except requests.exceptions.RequestException as e:
    print(f"Error downloading file: {e}")
    if os.path.exists(tmp_path):
        os.remove(tmp_path)

Downloading https://download.pytorch.org/models/resnet34-333f7ec4.pth to /resnet-34-kinetics-cpu.pth...
Download complete.


In [None]:
# Execute the Pipeline

# Define file paths
# Directory that is searched for *.mrc files by the feature-extraction functions.
data_path = './cryo-ET-samples'
# Same location the download cell writes the checkpoint to.
resnet_weights_path = '/resnet-34-kinetics-cpu.pth'
# Directory the t-SNE plots are saved into.
output_dir = './results'

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)



In [None]:
# --- Feature Extraction and Analysis Functions ---

def extract_features_resnet(model, data_path):
    """
    Extracts features from .mrc files using the 3D-ResNet34 model.

    Every ``*.mrc`` file directly inside ``data_path`` is opened, preprocessed
    with ``preprocess_subtomogram_resnet`` and passed through ``model`` in eval
    mode with gradients disabled. Files that fail are reported and skipped.

    Args:
        model: ``nn.Module`` that maps the preprocessed tensor to a feature tensor.
        data_path: Directory to search for ``.mrc`` files.

    Returns:
        ``(features, ids)``: a NumPy array with one flattened feature vector per
        successfully processed file, and the matching list of file names with
        the final extension removed.
    """
    model.eval()
    # Run inputs on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    features = []
    subtomogram_ids = []
    mrc_files = sorted(glob.glob(os.path.join(data_path, '*.mrc')))
    print(f"Found {len(mrc_files)} .mrc files for ResNet processing.")
    with torch.no_grad():
        for f in mrc_files:
            try:
                with mrcfile.open(f, permissive=True) as mrc:
                    # splitext drops only the final extension, so names that
                    # contain dots (e.g. "a.b.mrc") keep distinct IDs.
                    subtomogram_id = os.path.splitext(os.path.basename(f))[0]
                    preprocessed_data = preprocess_subtomogram_resnet(mrc.data)
                    feature = model(preprocessed_data.to(device))
                    features.append(feature.cpu().numpy().flatten())
                    subtomogram_ids.append(subtomogram_id)
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole run.
                print(f"Could not process file {f}: {e}")
    return np.array(features), subtomogram_ids

def extract_features_videomae(model, processor, data_path):
    """
    Extracts features from .mrc files using the VideoMAE model.

    Every ``*.mrc`` file directly inside ``data_path`` is opened, converted to
    model inputs with ``preprocess_subtomogram_videomae`` and passed through
    ``model`` in eval mode with gradients disabled. The feature of a file is the
    mean of ``last_hidden_state`` over dim 1. Files that fail are reported and
    skipped.

    Args:
        model: ``nn.Module`` called as ``model(**inputs)``; its output must
            expose ``last_hidden_state``.
        processor: Image processor handed to ``preprocess_subtomogram_videomae``.
        data_path: Directory to search for ``.mrc`` files.

    Returns:
        ``(features, ids)``: a NumPy array with one flattened feature vector per
        successfully processed file, and the matching list of file names with
        the final extension removed.
    """
    model.eval()
    # Run inputs on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    features = []
    subtomogram_ids = []
    mrc_files = sorted(glob.glob(os.path.join(data_path, '*.mrc')))
    print(f"Found {len(mrc_files)} .mrc files for VideoMAE processing.")
    with torch.no_grad():
        for f in mrc_files:
            try:
                with mrcfile.open(f, permissive=True) as mrc:
                    # splitext drops only the final extension, so names that
                    # contain dots (e.g. "a.b.mrc") keep distinct IDs.
                    subtomogram_id = os.path.splitext(os.path.basename(f))[0]
                    inputs = preprocess_subtomogram_videomae(mrc.data, processor)
                    inputs = {key: (value.to(device) if hasattr(value, "to") else value)
                              for key, value in inputs.items()}
                    outputs = model(**inputs)
                    feature = outputs.last_hidden_state.mean(dim=1)
                    features.append(feature.cpu().numpy().flatten())
                    subtomogram_ids.append(subtomogram_id)
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole run.
                print(f"Could not process file {f}: {e}")
    return np.array(features), subtomogram_ids

def run_tsne_kmeans_and_plot(features, ids, model_name, output_dir, n_clusters=4):
    """
    Performs t-SNE and K-Means clustering and plots the results.

    The 2-D t-SNE embedding of ``features`` is drawn as a scatter plot coloured
    by the K-Means cluster of each sample (clusters are computed on the original
    features, not on the embedding), every point is annotated with its ID, and
    the figure is saved as ``t_sne_<model_name>.png`` inside ``output_dir``.

    Args:
        features: 2-D array-like, one row per sample.
        ids: Labels to annotate the points with, in row order.
        model_name: Used in the title and, lower-cased with spaces replaced by
            underscores, in the output file name.
        output_dir: Existing directory to save the plot into.
        n_clusters: Requested number of K-Means clusters; capped at the number
            of samples.

    Returns:
        None. When fewer than two samples are given, a message is printed and
        nothing is plotted, because t-SNE needs a perplexity that is positive
        and below the sample count.
    """
    print(f"Running t-SNE and K-Means for {model_name}...")
    features = np.asarray(features)
    n_samples = len(features)
    if n_samples < 2:
        print(f"Skipping {model_name}: need at least 2 samples for t-SNE, got {n_samples}.")
        return

    tsne = TSNE(n_components=2, random_state=42, perplexity=min(5, n_samples - 1))
    tsne_results = tsne.fit_transform(features)

    # K-Means cannot form more clusters than there are samples.
    n_clusters = max(1, min(n_clusters, n_samples))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    clusters = kmeans.fit_predict(features)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=clusters, cmap='viridis')
    plt.title(f't-SNE visualization of {model_name} features')
    plt.xlabel('t-SNE component 1')
    plt.ylabel('t-SNE component 2')
    # One legend entry per cluster actually present, labelled with its index.
    handles, _ = scatter.legend_elements(num=None)
    plt.legend(handles=handles, labels=[str(c) for c in np.unique(clusters)])

    # Add annotations
    for i, txt in enumerate(ids):
        plt.annotate(txt, (tsne_results[i, 0], tsne_results[i, 1]))

    plot_path = os.path.join(output_dir, f't_sne_{model_name.lower().replace(" ", "_")}.png')
    plt.savefig(plot_path)
    plt.close()
    print(f"Plot saved to {plot_path}")

In [None]:
# --- Part 1: 3D-ResNet34 ---
print("\n--- Starting 3D-ResNet34 Analysis ---")
if os.path.exists(resnet_weights_path):
    resnet_model = generate_resnet34(n_classes=1000, n_input_channels=3)
    print(f"Loading ResNet weights from {resnet_weights_path}...")
    # PyTorch >= 2.6 defaults to weights_only=True, which rejects checkpoints
    # saved in the legacy format. weights_only=False unpickles arbitrary
    # objects, so only use it with a checkpoint from a trusted source.
    checkpoint = torch.load(resnet_weights_path, map_location=torch.device('cpu'), weights_only=False)

    # Accept both a bare state dict and one wrapped under 'state_dict', and
    # drop a leading 'module.' prefix from the keys if present.
    state_dict = checkpoint
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v
                  for k, v in state_dict.items()}

    # Copy every tensor whose name matches: convolutions (including the
    # shortcut convolutions), batch-norm parameters and running statistics,
    # and the linear head. 2D kernels are inflated to 3D.
    model_state = resnet_model.state_dict()
    n_loaded = 0
    with torch.no_grad():
        for name, target in model_state.items():
            source = state_dict.get(name)
            if not torch.is_tensor(source):
                continue
            if source.shape == target.shape:
                target.copy_(source)
            elif (target.dim() == 5 and source.dim() == 4
                  and source.shape == target.shape[:2] + target.shape[3:]):
                # Inflate 2D weights to 3D: repeat along the new axis and divide
                # by its length so the response to a volume that is constant
                # along that axis equals the 2D response, keeping the copied
                # batch-norm statistics meaningful.
                depth = target.shape[2]
                target.copy_(source.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
            else:
                continue
            n_loaded += 1
    print(f"Loaded {n_loaded} of {len(model_state)} tensors from the checkpoint.")

    resnet_extractor = ResNetFeatureExtractor(resnet_model)
    resnet_features, resnet_ids = extract_features_resnet(resnet_extractor, data_path)
    if resnet_features.size > 0:
        run_tsne_kmeans_and_plot(resnet_features, resnet_ids, "3D-ResNet34", output_dir)
    else:
        print("No features were extracted from 3D-ResNet34.")
else:
    print(f"Error: ResNet weights file not found at {resnet_weights_path}. Please upload it.")


--- Starting 3D-ResNet34 Analysis ---
Loading ResNet weights from /resnet-34-kinetics-cpu.pth...


RuntimeError: Cannot use ``weights_only=True`` with files saved in the legacy .tar format. In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.

In [None]:
# --- Part 2: VideoMAE ---
print("\n--- Starting VideoMAE Analysis ---")
if VideoMAEModel is None:
    # The optional transformers import failed earlier; nothing to run.
    print("Skipping VideoMAE analysis because 'transformers' library is not installed.")
else:
    print("Loading VideoMAE model and processor from Hugging Face...")
    videomae_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
    videomae_model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")
    videomae_features, videomae_ids = extract_features_videomae(
        videomae_model, videomae_processor, data_path
    )
    if videomae_features.size == 0:
        print("No features were extracted from VideoMAE.")
    else:
        run_tsne_kmeans_and_plot(videomae_features, videomae_ids, "VideoMAE", output_dir)

print("\n--- Pipeline Finished ---")

In [None]:
# Display Output Images

# Imported under an alias: a plain `Image` import here would rebind the global
# `Image` that the VideoMAE preprocessing uses for `Image.fromarray`, breaking
# any re-run of the pipeline after this cell.
from IPython.display import Image as IPyImage, display

resnet_plot_path = './results/t_sne_3d-resnet34.png'
videomae_plot_path = './results/t_sne_videomae.png'

print("--- 3D-ResNet34 Feature Visualization ---")
if os.path.exists(resnet_plot_path):
    display(IPyImage(filename=resnet_plot_path))
else:
    print("ResNet plot not found.")

print("\n--- VideoMAE Feature Visualization ---")
if os.path.exists(videomae_plot_path):
    display(IPyImage(filename=videomae_plot_path))
else:
    print("VideoMAE plot not found.")