Skip to content

Commit

Permalink
Create basic CNN feature extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
Cubevoid committed Mar 2, 2024
1 parent 70241ec commit 0a3b249
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions model/feat_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch

class FeatExtractor(torch.nn.Module):
def __init__(self):
super(FeatExtractor, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 64, 3, 1, 1)
self.conv2 = torch.nn.Conv2d(64, 128, 3, 1, 1)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(2, 2)

def forward(self, x: torch.Tensor):
"""
Args:
x: (B, C, H, W) input image tensor
"""
x = self.maxpool(self.relu(self.conv1(x)))
x = self.relu(self.conv2(x))
return x.flatten(1)

0 comments on commit 0a3b249

Please sign in to comment.