# In [1]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import glob
from pathlib import Path
import torch.nn as nn
import torch.optim as optim
import torch
import torch.nn.functional as F

# In [2]:
class BasicBlock1D(nn.Module):
    """Two-convolution residual block for 1D signals.

    Main branch: conv(kernel_size, stride) -> BN -> ReLU -> conv(kernel_size, 1)
    -> BN. The result is added to a shortcut and passed through a final ReLU.
    When the channel count or the stride changes, the shortcut is a 1x1 strided
    convolution (``self.skip``) so both branches have the same shape; otherwise
    the input is added unchanged.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        stride: stride of the first convolution and of the shortcut projection;
            values > 1 shorten the sequence.
        kernel_size: kernel length of both main-branch convolutions. Must be a
            positive odd integer (see the check in ``__init__``).

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_channels, out_channels, stride=1, kernel_size=7):
        super(BasicBlock1D, self).__init__()
        # Padding of (k - 1) // 2 only preserves the length for odd kernels.
        # With an even kernel each stride-1 conv drops one sample, so the
        # residual addition in forward() would always fail with a shape error.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size,
                               stride=1, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # A projection shortcut is required only when the main branch changes
        # the channel count or the length; `downsample` records which path
        # forward() takes.
        self.downsample = in_channels != out_channels or stride != 1
        if self.downsample:
            self.skip = nn.Conv1d(in_channels, out_channels, 1, stride=stride)

    def forward(self, x):
        """Apply the block to a Conv1d-style input (batch, in_channels, length)."""
        identity = self.skip(x) if self.downsample else x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += identity
        return self.relu(out)

# In [3]:
def output_conv_size(width, kernel, padding, stride, dilation=1):
    """Return the output length of a 1D convolution / pooling window.

    Matches the formula used by ``torch.nn.Conv1d`` (and ``MaxPool1d`` with
    ``ceil_mode=False``)::

        floor((width + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1

    With the default ``dilation=1`` this reduces to
    ``floor((width - kernel + 2*padding) / stride) + 1``. The floor matters:
    a true division would report fractional lengths such as 50.5 for a layer
    that really produces 50 samples.

    Args:
        width: input sequence length.
        kernel: kernel (window) size.
        padding: zero padding added on each side.
        stride: step between windows.
        dilation: spacing between kernel taps (default 1).

    Returns:
        The number of output positions (an ``int`` for integer inputs).
    """
    return (width + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# In [4]:

class ResNet1D(nn.Module):
    """1D ResNet for multi-channel sequence classification (e.g. 12-lead ECG).

    Layout:
        stem:    Conv1d(stem_kernel_size, stride 2) -> BN -> ReLU
                 -> MaxPool1d(3, stride 2)
        stages:  four residual stages of widths initial_filters x 1, 2, 4, 8;
                 stages 2-4 start with stride 2
        head:    global average pooling -> Linear(initial_filters * 8, num_classes)
    """

    def __init__(self, block, layers, num_classes=5, input_channels=12,
                 initial_filters=64, kernel_size=7, stem_kernel_size=15):
        """
        Args:
            block: residual block class, instantiated as
                ``block(in_channels, out_channels, stride, kernel_size)``.
            layers: sequence of exactly four positive ints, the number of
                blocks in each stage.
            num_classes: number of output classes (width of the final Linear).
            input_channels: number of input channels (12 for 12-lead ECG).
            initial_filters: number of filters in the stem conv and stage 1.
            kernel_size: kernel size passed to every residual block. The stem
                convolution does not use it; see ``stem_kernel_size``.
            stem_kernel_size: kernel size of the stem convolution. Its padding
                is ``stem_kernel_size // 2`` (7 for the default of 15).

        Raises:
            ValueError: if ``layers`` does not contain exactly four entries,
                or any entry is smaller than 1.
        """
        super(ResNet1D, self).__init__()
        if len(layers) != 4:
            raise ValueError(
                f"layers must have exactly 4 entries, got {len(layers)}"
            )
        # _make_layer always builds at least one block, so 0 would silently
        # produce a block that was not asked for.
        if any(n < 1 for n in layers):
            raise ValueError(f"every entry of layers must be >= 1, got {list(layers)}")

        # Running channel count consumed and updated by _make_layer.
        self.in_channels = initial_filters
        self.kernel_size = kernel_size

        # Stem: strided convolution followed by a strided max-pool.
        self.conv1 = nn.Conv1d(input_channels, initial_filters,
                               kernel_size=stem_kernel_size, stride=2,
                               padding=stem_kernel_size // 2, bias=False)
        self.bn1 = nn.BatchNorm1d(initial_filters)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Residual stages; the first block of stages 2-4 halves the length.
        self.layer1 = self._make_layer(block, initial_filters, layers[0], stride=1)
        self.layer2 = self._make_layer(block, initial_filters*2, layers[1], stride=2)
        self.layer3 = self._make_layer(block, initial_filters*4, layers[2], stride=2)
        self.layer4 = self._make_layer(block, initial_filters*8, layers[3], stride=2)

        # Global average pooling and classifier.
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(initial_filters*8, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks as an nn.Sequential.

        Only the first block receives ``stride`` and the channel change. The
        method also updates ``self.in_channels`` to ``out_channels`` for the
        next stage.
        """
        layers = []
        # First block may change channels and/or downsample.
        layers.append(block(self.in_channels, out_channels, stride, self.kernel_size))
        self.in_channels = out_channels

        # Remaining blocks keep channels and length.
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels, 1, self.kernel_size))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Initialize parameters in place.

        Conv1d weights use Kaiming-normal (fan_out, ReLU). BatchNorm1d uses
        weight=1 and bias=0. Linear uses weight ~ N(0, 0.01) and bias=0.
        Conv1d biases, where present, keep PyTorch's default initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a Conv1d-style input (batch, input_channels, length) to logits
        of shape (batch, num_classes)."""
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Residual stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global pooling and classification
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


# Model factory functions
def resnet18_1d(num_classes=5, input_channels=12):
    """Build the 18-layer 1D ResNet: two BasicBlock1D blocks in each of the
    four stages, with ResNet1D's defaults for everything else."""
    blocks_per_stage = [2] * 4
    return ResNet1D(
        block=BasicBlock1D,
        layers=blocks_per_stage,
        num_classes=num_classes,
        input_channels=input_channels,
    )

def resnet34_1d(num_classes=5, input_channels=12):
    """Build the 34-layer 1D ResNet: BasicBlock1D stages of 3, 4, 6 and 3
    blocks, with ResNet1D's defaults for everything else."""
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet1D(
        block=BasicBlock1D,
        layers=blocks_per_stage,
        num_classes=num_classes,
        input_channels=input_channels,
    )

def resnet50_1d(num_classes=5, input_channels=12):
    """Build the model exposed under the name "ResNet-50 1D".

    NOTE: despite the name, this is not a true ResNet-50. A ResNet-50 stacks
    Bottleneck blocks in a [3, 4, 6, 3] layout; this factory uses BasicBlock1D
    with that layout and 64 initial filters. The result is therefore
    architecturally identical to ``resnet34_1d`` (same blocks, same stage
    sizes, same widths). Choose ``resnet34_1d`` if you want the name to
    reflect the architecture.
    """
    # initial_filters is left at ResNet1D's default of 64, exactly as before.
    return ResNet1D(BasicBlock1D, [3, 4, 6, 3], num_classes, input_channels)
