-
Notifications
You must be signed in to change notification settings - Fork 0
/
Model3.py
29 lines (25 loc) · 894 Bytes
/
Model3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from torch import nn as nn
import torch
import torchvision
class RegressionModel(nn.Module):
    """Two-branch regressor fusing a ResNet-18 image embedding with a mitotic count.

    Image branch: ResNet-18 without its final ``fc`` layer, flattened to a
    512-dim feature vector, then ``fc3`` (512 -> 1) and a sigmoid.
    Mitotic-count (MC) branch: ``fc1`` (1 -> 1), sigmoid, then ``fc2`` (1 -> 1).
    Fusion head: the two scalars are concatenated and mapped by ``fc4`` (2 -> 1).
    """

    def __init__(self):
        super().__init__()
        # ImageNet-pretrained backbone. ``weights=IMAGENET1K_V1`` is the
        # torchvision>=0.13 spelling of the deprecated ``pretrained=True``;
        # fall back to the old keyword on torchvision versions without the enum.
        weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
        if weights_enum is not None:
            resnet18 = torchvision.models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet18 = torchvision.models.resnet18(pretrained=True)
        # Drop the classification head (last child) and flatten the pooled
        # output so the backbone yields a (B, 512) feature matrix.
        self.model = nn.Sequential(*(list(resnet18.children())[:-1]), nn.Flatten())
        # MC path
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 1)
        # Image path
        self.fc3 = nn.Linear(512, 1)
        # Fusion head: combines the image scalar and the MC scalar.
        self.fc4 = nn.Linear(2, 1)

    def forward(self, image, mitotic_count):
        """Predict one regression value per sample.

        Args:
            image: batch of images fed to the ResNet-18 backbone; assumed to be a
                (B, 3, H, W) float tensor preprocessed like ImageNet — TODO confirm
                against the data pipeline.
            mitotic_count: per-sample count, shape (B, 1) or (B,). Integer
                tensors are accepted and cast to the layer dtype.

        Returns:
            Tensor of shape (B, 1).
        """
        # Image path
        x1 = torch.sigmoid(self.fc3(self.model(image)))
        # MC path: nn.Linear(1, 1) needs a trailing feature dim of size 1, so
        # promote a flat (B,) vector to (B, 1) and match the weight dtype.
        mc = mitotic_count
        if mc.dim() == 1:
            mc = mc.unsqueeze(1)
        mc = mc.to(dtype=self.fc1.weight.dtype)
        x2 = self.fc2(torch.sigmoid(self.fc1(mc)))
        # Concatenate the image and the MC path, then fuse.
        x = torch.cat((x1, x2), dim=1)
        return self.fc4(x)