Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

a more configurable resnet3d implementation #257

Merged
merged 4 commits into from Jan 22, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
195 changes: 140 additions & 55 deletions fuse/dl/models/backbones/backbone_resnet_3d.py
Expand Up @@ -17,78 +17,163 @@

"""

from typing import Tuple, Any

import torch.nn as nn
from torch import Tensor
from torch.hub import load_state_dict_from_url
from torchvision.models.video.resnet import VideoResNet, BasicBlock, Conv3DSimple, BasicStem, model_urls
from torchvision.models.video.resnet import model_urls
from typing import Tuple, Optional, Callable, List, Sequence, Type


class Conv3DSimple(nn.Conv3d):
    """Bias-free 3x3x3 convolution exposing the conv-builder call signature."""

    def __init__(
        self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
    ) -> None:
        """
        :param in_planes: number of input channels
        :param out_planes: number of output channels
        :param midplanes: accepted for signature compatibility; not used by this builder
        :param stride: stride forwarded to nn.Conv3d
        :param padding: padding forwarded to nn.Conv3d
        """
        conv_kwargs = dict(kernel_size=(3, 3, 3), stride=stride, padding=padding, bias=False)
        super().__init__(in_planes, out_planes, **conv_kwargs)

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        """Return the given stride repeated once per axis of the 3D volume."""
        return (stride,) * 3


class BasicBlock(nn.Module):
    """Residual block: two conv + batch-norm units whose output is added to a skip path."""

    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        """
        :param inplanes: number of channels entering the block
        :param planes: number of channels produced by the block
        :param conv_builder: callable building a convolution from (in, out, midplanes[, stride])
        :param stride: stride handed to the first convolution only
        :param downsample: optional module applied to the input before it is added back
        """
        super().__init__()

        # intermediate width handed to the conv builder; builders are free to ignore it
        kernel_volume = 3 * 3 * 3
        midplanes = (inplanes * planes * kernel_volume) // (inplanes * 3 * 3 + 3 * planes)

        first_conv = conv_builder(inplanes, planes, midplanes, stride)
        self.conv1 = nn.Sequential(first_conv, nn.BatchNorm3d(planes), nn.ReLU(inplace=True))

        second_conv = conv_builder(planes, planes, midplanes)
        self.conv2 = nn.Sequential(second_conv, nn.BatchNorm3d(planes))

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Run both conv units, add the (optionally downsampled) input and apply ReLU."""
        out = self.conv2(self.conv1(x))

        # the skip path uses the raw input unless a downsample module was supplied
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem"""

    def __init__(self, in_channels=3, out_channels=64, kernel_size=(3, 7, 7), stride=(1, 2, 2)) -> None:
        """
        :param in_channels: number of channels of the input volume
        :param out_channels: number of channels produced by the stem
        :param kernel_size: per-axis kernel size of the convolution
        :param stride: per-axis stride of the convolution
        """
        # pad every axis by half of its kernel extent
        padding = tuple(extent // 2 for extent in kernel_size)
        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm3d(out_channels)
        activation = nn.ReLU(inplace=True)
        super().__init__(conv, norm, activation)


class BackboneResnet3D(nn.Module):
    """
    A slightly more configurable ResNet3D implementation.
    Default values are identical to the pytorch implementation.

    The module is a stem followed by four residual stages. It has no classification head:
    the output is either the spatial feature map or, when pooling is enabled, a feature vector.
    """

    def __init__(
        self,
        *,
        in_channels: int = 3,
        pool: bool = True,
        layers: Sequence[int] = (2, 2, 2, 2),
        first_channel_dim: int = 64,
        first_stride: int = 1,
        stem_kernel_size: Sequence[int] = (3, 7, 7),
        stem_stride: Sequence[int] = (1, 2, 2),
        pretrained: bool = False,
        name: str = "r3d_18",
    ) -> None:
        """
        Create 3D ResNet model
        :param in_channels: number of input channels
        :param pool: if True, features are average-pooled to 1x1x1 and flattened
        :param layers: number of residual blocks in each of the four stages
        :param first_channel_dim: channel width of the stem and first stage; each later stage doubles it
        :param first_stride: stride of the first stage (the remaining stages use stride 2)
        :param stem_kernel_size: per-axis kernel size of the stem convolution
        :param stem_stride: per-axis stride of the stem convolution
        :param pretrained: load pretrained weights instead of random initialization.
            Weights are loaded strictly, so arguments that change parameter shapes or
            block counts relative to the checkpoint will make loading fail.
        :param name: key into torchvision's model_urls; only used when pretrained is True
        """
        super().__init__()
        # running channel count, updated by _make_layer as stages are appended
        self.inplanes = first_channel_dim
        self.stem = BasicStem(in_channels, first_channel_dim, kernel_size=stem_kernel_size, stride=stem_stride)

        self.layer1 = self._make_layer(BasicBlock, Conv3DSimple, first_channel_dim, layers[0], stride=first_stride)
        self.layer2 = self._make_layer(BasicBlock, Conv3DSimple, first_channel_dim * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(BasicBlock, Conv3DSimple, first_channel_dim * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(BasicBlock, Conv3DSimple, first_channel_dim * 8, layers[3], stride=2)
        # channel count of the features produced by the last stage
        self.out_dim = first_channel_dim * 8
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self._pool = pool

        if pretrained:
            state_dict = load_state_dict_from_url(model_urls[name])
            # The torchvision checkpoints carry the classifier head ("fc.*"). This backbone
            # defines no such layer, and strict loading rejects unexpected keys, so drop them.
            state_dict = {key: value for key, value in state_dict.items() if not key.startswith("fc.")}
            self.load_state_dict(state_dict)
        else:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm3d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

    def _make_layer(
        self,
        block: "Type[BasicBlock]",
        conv_builder: "Type[Conv3DSimple]",
        planes: int,
        blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """
        Build one residual stage and advance self.inplanes to the stage's output width.
        :param block: residual block class
        :param conv_builder: convolution class handed to every block
        :param planes: base channel width of the stage
        :param blocks: number of blocks in the stage
        :param stride: stride of the first block; the following blocks use the default stride
        :return: the blocks of the stage wrapped in nn.Sequential
        """
        downsample = None

        # the skip path needs a projection whenever resolution or channel count changes
        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )

        stage = [block(self.inplanes, planes, conv_builder, stride, downsample)]

        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*stage)

    def features(self, x: Tensor) -> Tensor:
        """
        Extract features - given a 3D tensor
        :param x: Input tensor - shape: [batch_size, channels, z, y, x]
        :return: pooled features - shape [batch_size, out_dim] when pooling is enabled,
            otherwise spatial features - shape [batch_size, out_dim, z', y', x']
        """
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self._pool:
            x = self.avgpool(x)
            x = x.flatten(1)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass - identical to features()
        :param x: Input volume. shape: [batch_size, channels, z, y, x]
        :return: the output of features(x)
        """
        return self.features(x)