In [1]:
%run 'ML_imports.ipynb'



In [2]:
class Conv_Forward(nn.Module):
    """One convolutional stage: Conv2d -> LeakyReLU -> BatchNorm2d -> Dropout.

    The convolution uses a 3x3 kernel with padding 1 and no bias term.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        dropout: drop probability handed to ``nn.Dropout``.
    """

    def __init__(self, in_channels, out_channels, dropout):
        super().__init__()
        # Layers are listed in application order; their positions (0..3)
        # are also the indices used in the state-dict keys under ``fc``.
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout),
        ]
        self.fc = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage and return the resulting tensor."""
        result = self.fc(x)
        return result

In [3]:
class Teacher(nn.Module):
    """Convolutional classifier that exposes its intermediate activations.

    The network is a 3x3 convolution (padding 1) followed by LeakyReLU and
    dropout, then ``num_conv`` ``Conv_Forward`` stages, a mean over
    dimensions 2 and 3, and a two-layer classifier head with a hidden
    width of 30.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels used by every convolutional stage.
        num_conv: how many ``Conv_Forward`` stages to stack.
        num_classes: size of the classifier output.
        dropout: drop probability used by every dropout layer.
    """

    def __init__(self, in_channels, out_channels, num_conv, num_classes, dropout):
        super().__init__()
        self.num_conv = num_conv

        # First stage: kept as three separate attributes so the state-dict
        # key for the convolution stays ``conv_1.*``.
        self.conv_1 = nn.Conv2d(in_channels, out_channels, padding=1, kernel_size=3)
        self.activation = nn.LeakyReLU()
        self.drop_1 = nn.Dropout(dropout)

        # Stack of identical channel-preserving stages.
        stacked_stages = []
        for _ in range(num_conv):
            stacked_stages.append(
                Conv_Forward(out_channels, out_channels, dropout=dropout)
            )
        self.fc = nn.Sequential(*stacked_stages)

        # Classification head applied to the pooled features.
        head_layers = [
            nn.Linear(out_channels, 30),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(30, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return every intermediate result of the forward pass.

        Returns:
            A 4-tuple ``(out_1, out_2, out_3, out_4)``: the output of the
            first stage, the output of the ``Conv_Forward`` stack, the mean
            of that output over dimensions 2 and 3, and the classifier
            output.
        """
        first_stage = self.conv_1(x)
        first_stage = self.activation(first_stage)
        first_stage = self.drop_1(first_stage)

        stacked = self.fc(first_stage)
        pooled = stacked.mean(dim=(2, 3))
        logits = self.classifier(pooled)
        return first_stage, stacked, pooled, logits

In [4]:
class Student(nn.Module):
    """Two-stage convolutional classifier that can borrow teacher activations.

    Each convolutional stage is Conv2d (3x3, padding 1) -> LeakyReLU ->
    Dropout. The classifier head is Linear -> LeakyReLU -> Dropout -> Linear
    with a hidden width of 30.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutional stages.
        num_classes: size of the classifier output.
        dropout: drop probability used by every dropout layer, including
            the two convolutional stages.
    """

    def __init__(self, in_channels, out_channels, num_classes, dropout):
        super(Student, self).__init__()
        # The configured ``dropout`` rate is applied in the convolutional
        # stages as well; a bare ``nn.Dropout()`` would silently fall back
        # to p=0.5 regardless of the argument.
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Dropout(dropout)
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Dropout(dropout)
        )

        self.classifier = nn.Sequential(
            nn.Linear(out_channels, 30),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(30, num_classes)
        )

    def forward(self, x, teacher_out_1=None, teacher_out_2=None, teachers_input_student_ratio=10):
        """Run the student, optionally substituting teacher activations.

        Args:
            x: input tensor for the first convolutional stage.
            teacher_out_1: optional teacher activation that may replace the
                student's own first-stage output as input to ``conv_2``.
            teacher_out_2: optional teacher activation that may replace the
                student's own second-stage output before pooling.
            teachers_input_student_ratio: threshold compared against a
                uniform draw from ``np.random.rand()``; a teacher tensor is
                used only when the draw is strictly greater than this
                value. ``np.random.rand()`` returns values in [0, 1), so
                the default of 10 never selects the teacher tensor.
                NOTE(review): the name suggests a ratio, but the value is
                used as a threshold; confirm the intended range with the
                training loop.

        Returns:
            A 4-tuple ``(out_1, out_2, out_3, out_4)``: first-stage output,
            second-stage output, pooled features (mean over dimensions 2
            and 3), and classifier output. ``out_1`` and ``out_2`` are
            always computed by the student's own layers, even when a
            teacher tensor was substituted downstream.
        """
        out_1 = self.conv_1(x)

        # At most one teacher tensor survives this coin flip, so the two
        # substitution points below are mutually exclusive.
        if np.random.rand() > 0.5:
            teacher_out_2 = None
        else:
            teacher_out_1 = None

        # Short-circuit matters: the random draw is only consumed when a
        # teacher tensor is actually available.
        if (teacher_out_1 is not None) and (np.random.rand() > teachers_input_student_ratio):
            out_2 = self.conv_2(teacher_out_1.detach())
        else:
            out_2 = self.conv_2(out_1)

        if (teacher_out_2 is not None) and (np.random.rand() > teachers_input_student_ratio):
            out_3 = torch.mean(teacher_out_2.detach(), dim=(2, 3))
        else:
            out_3 = torch.mean(out_2, dim=(2, 3))

        out_4 = self.classifier(out_3)

        return out_1, out_2, out_3, out_4

In [30]:
class TeacherStudentNetwork(nn.Module):
    """Pairs a pretrained ``Teacher`` with a trainable ``Student``.

    The teacher's weights are loaded from ``teacher_weights_path`` at
    construction time. The teacher is kept permanently in evaluation mode
    so that its dropout layers are inactive and its BatchNorm running
    statistics are not overwritten while the student is being trained.

    Args:
        in_channels: channels of the input tensor (shared by both models).
        out_channels: channels of the convolutional stages (shared).
        teacher_num_conv: number of ``Conv_Forward`` stages in the teacher.
        teacher_weights_path: path to a state dict saved for ``Teacher``.
        num_classes: size of both classifier outputs.
        dropout: drop probability passed to both models.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        teacher_num_conv,
        teacher_weights_path,
        num_classes,
        dropout,
    ):
        super(TeacherStudentNetwork, self).__init__()
        self.teacher =  Teacher(
            in_channels=in_channels,
            out_channels=out_channels,
            num_conv=teacher_num_conv,
            num_classes=num_classes,
            dropout=dropout
        )

        self.student = Student(
            in_channels=in_channels,
            out_channels=out_channels,
            num_classes=num_classes,
            dropout=dropout,
        )

        # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
        # machine without one; load_state_dict then copies the values into
        # the teacher's existing parameters and buffers.
        teacher_state = torch.load(teacher_weights_path, map_location="cpu")
        self.teacher.load_state_dict(teacher_state)
        # Modules start in training mode; the frozen teacher must not.
        self.teacher.eval()

    def train(self, mode=True):
        """Switch the student's mode while keeping the teacher in eval mode.

        ``torch.no_grad()`` in ``forward`` stops gradient tracking only; it
        does not disable dropout or BatchNorm statistic updates, so the
        teacher's mode has to be pinned here.
        """
        super().train(mode)
        self.teacher.eval()
        return self

    def forward(self, x, teachers_input_student_ratio=10):
        """Run both models on ``x``.

        Args:
            x: input tensor fed to both teacher and student.
            teachers_input_student_ratio: forwarded unchanged to
                ``Student.forward``.

        Returns:
            A pair of lists ``(teacher_outputs, student_outputs)``, each
            holding the four tensors returned by the respective model. The
            teacher tensors are detached.
        """
        with torch.no_grad():
            t_out_1, t_out_2, t_out_3, t_out_4 = self.teacher(x)
        s_out_1, s_out_2, s_out_3, s_out_4 = self.student(x, t_out_1, t_out_2, teachers_input_student_ratio)
        return (
            [t_out_1.detach(), t_out_2.detach(), t_out_3.detach(), t_out_4.detach()],
            [s_out_1, s_out_2, s_out_3, s_out_4]
        )

In [16]:
# Smoke test: a random batch of two 3-channel 32x32 inputs.
img = torch.randn(2, 3, 32, 32)

# Teacher with four stacked Conv_Forward stages.
t = Teacher(
    in_channels=3,
    out_channels=32,
    num_conv=4,
    num_classes=10,
    dropout=0.2,
)

# Student with the same channel width and number of classes.
s = Student(
    in_channels=3,
    out_channels=32,
    num_classes=10,
    dropout=0.2,
)

# Keep all four teacher outputs for the cells below.
out_1, out_2, out_3, out_4 = t(img)

In [23]:
# Shapes of the four teacher outputs, in forward order.
tuple(out.shape for out in (out_1, out_2, out_3, out_4))

(torch.Size([2, 32, 32, 32]),
 torch.Size([2, 32, 32, 32]),
 torch.Size([2, 32]),
 torch.Size([2, 10]))

In [24]:
# Feed the student the same batch together with both teacher activations.
s_out_1, s_out_2, s_out_3, s_out_4 = s(
    img,
    teacher_out_1=out_1,
    teacher_out_2=out_2,
)

In [25]:
# Shapes of the four student outputs, in forward order.
tuple(out.shape for out in (s_out_1, s_out_2, s_out_3, s_out_4))

(torch.Size([2, 32, 32, 32]),
 torch.Size([2, 32, 32, 32]),
 torch.Size([2, 32]),
 torch.Size([2, 10]))

In [27]:
import os

path = "model/teacher_model.bin"
# torch.save does not create missing parent directories, so make sure
# "model/" exists before writing; exist_ok keeps re-runs of this cell safe.
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save(t.state_dict(), path)

In [31]:
# Build the combined network, loading the teacher weights saved above.
t_s = TeacherStudentNetwork(
    in_channels=3,
    out_channels=32,
    teacher_num_conv=4,
    teacher_weights_path=path,
    num_classes=10,
    dropout=0.2,
)