In [None]:
class Head(nn.Module):
    """Small two-layer MLP: Linear(in_dim, 128) -> ReLU -> Linear(128, out_dim).

    Args:
        in_dim: Size of the last dimension of the input tensor.
        out_dim: Size of the last dimension of the returned tensor.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()

        hidden_dim = 128
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Return the output of ``fc2`` with no activation applied to it."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
class JPM(nn.Module):
    """Joint prediction model over three targets: species (s), type (t), coordinate (c).

    A shared convolutional backbone maps the input to a 256-dimensional feature
    vector ``z``. Three families of ``Head`` modules then produce logits:

    * independent heads, which see only ``z``;
    * first-order joint heads, which see ``z`` plus a one-hot encoding of one
      other target;
    * second-order joint heads, which see ``z`` plus one-hot encodings of both
      other targets.

    Args:
        in_dim: Accepted for backward compatibility but unused. The backbone
            ends in ``AdaptiveAvgPool2d((1, 1))``, so the flattened feature size
            is always 128 regardless of the input's spatial size.
        s_dim: Number of species classes.
        t_dim: Number of type classes.
        c_dim: Number of coordinate classes.
    """

    def __init__(self, in_dim, s_dim, t_dim, c_dim):
        super().__init__()
        z_dim = 256

        # Kept on the module so forward() can size the one-hot encodings.
        self.s_dim = s_dim
        self.t_dim = t_dim
        self.c_dim = c_dim

        # Shared convolutional backbone ( Net_Z (X) )
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),

            nn.Flatten(),
            # Global average pooling leaves 128 channels x 1 x 1 = 128 features.
            nn.Linear(128, z_dim),
        )

        # Independent predictions ( Net_A (Z) )
        self.indep_s = Head(z_dim, s_dim)  # species
        self.indep_t = Head(z_dim, t_dim)  # type
        self.indep_c = Head(z_dim, c_dim)  # coordinate

        # Joint predictions 1 ( Net_B (A, Z) )
        self.joint_ts = Head(z_dim + t_dim, s_dim)  # type -> species
        self.joint_cs = Head(z_dim + c_dim, s_dim)  # coordinate -> species

        self.joint_st = Head(z_dim + s_dim, t_dim)  # species -> type
        self.joint_ct = Head(z_dim + c_dim, t_dim)  # coordinate -> type

        self.joint_sc = Head(z_dim + s_dim, c_dim)  # species -> coordinate
        self.joint_tc = Head(z_dim + t_dim, c_dim)  # type -> coordinate

        # Joint predictions 2 ( Net_C (A, B, Z) )
        self.joint_tcs = Head(z_dim + t_dim + c_dim, s_dim)  # type, coordinate -> species
        self.joint_cts = Head(z_dim + c_dim + t_dim, s_dim)  # coordinate, type -> species

        self.joint_sct = Head(z_dim + s_dim + c_dim, t_dim)  # species, coordinate -> type
        self.joint_cst = Head(z_dim + c_dim + s_dim, t_dim)  # coordinate, species -> type

        self.joint_stc = Head(z_dim + s_dim + t_dim, c_dim)  # species, type -> coordinate
        self.joint_tsc = Head(z_dim + t_dim + s_dim, c_dim)  # type, species -> coordinate

    @staticmethod
    def _one_hot(labels, logits, num_classes):
        """One-hot encode ``labels``; fall back to the argmax of ``logits`` when ``labels`` is None."""
        if labels is None:
            # detach() keeps gradients from flowing through the conditioning path.
            labels = logits.detach().argmax(dim=1)
        # F.one_hot only accepts int64 index tensors.
        return F.one_hot(labels.long(), num_classes=num_classes).float()

    def forward(self, x, y_s=None, y_t=None, y_c=None):
        """Compute every independent and joint prediction for a batch.

        Args:
            x: Input batch of shape ``(batch, 1, H, W)``. The first convolution
                takes one channel and the three ``MaxPool2d(2)`` stages need
                ``H`` and ``W`` to be at least 8.
            y_s: Optional integer class indices for species, shape ``(batch,)``.
            y_t: Optional integer class indices for type, shape ``(batch,)``.
            y_c: Optional integer class indices for coordinate, shape ``(batch,)``.

        Any label that is ``None`` is replaced by the model's own prediction:
        the argmax of the independent head when conditioning the first-order
        joint heads, and the argmax of the matching first-order joint head when
        conditioning the second-order joint heads. Supplied labels are used
        as-is at both stages.

        Returns:
            Tuple ``(s_logits, t_logits, c_logits)``. Each is a list of five
            logit tensors of shape ``(batch, n_classes)``, ordered as the
            independent head, the two first-order joint heads, then the two
            second-order joint heads.
        """
        # Shared backbone; its output is already flattened to (batch, z_dim).
        z = self.backbone(x)

        # Independent predictions
        s = self.indep_s(z)
        t = self.indep_t(z)
        c = self.indep_c(z)

        # First-order conditioners.
        s_cond = self._one_hot(y_s, s, self.s_dim)
        t_cond = self._one_hot(y_t, t, self.t_dim)
        c_cond = self._one_hot(y_c, c, self.c_dim)

        zs = torch.cat([z, s_cond], dim=1)
        zt = torch.cat([z, t_cond], dim=1)
        zc = torch.cat([z, c_cond], dim=1)

        # Joint predictions 1
        ts = self.joint_ts(zt)
        cs = self.joint_cs(zc)
        st = self.joint_st(zs)
        ct = self.joint_ct(zc)
        sc = self.joint_sc(zs)
        tc = self.joint_tc(zt)

        # Second-order conditioners: "b_after_a" is target b as predicted
        # from a (or the supplied label for b when one was given).
        s_after_t = self._one_hot(y_s, ts, self.s_dim)
        s_after_c = self._one_hot(y_s, cs, self.s_dim)
        t_after_s = self._one_hot(y_t, st, self.t_dim)
        t_after_c = self._one_hot(y_t, ct, self.t_dim)
        c_after_s = self._one_hot(y_c, sc, self.c_dim)
        c_after_t = self._one_hot(y_c, tc, self.c_dim)

        # z<a><b> = [z, a, b predicted from a]; it feeds the head for chain a -> b -> final.
        zts = torch.cat([z, t_cond, s_after_t], dim=1)
        zcs = torch.cat([z, c_cond, s_after_c], dim=1)
        zst = torch.cat([z, s_cond, t_after_s], dim=1)
        zct = torch.cat([z, c_cond, t_after_c], dim=1)
        zsc = torch.cat([z, s_cond, c_after_s], dim=1)
        ztc = torch.cat([z, t_cond, c_after_t], dim=1)

        # Joint predictions 2: each head gets the concatenation whose order
        # matches the input layout it was constructed with.
        tcs = self.joint_tcs(ztc)
        cts = self.joint_cts(zct)
        sct = self.joint_sct(zsc)
        cst = self.joint_cst(zcs)
        stc = self.joint_stc(zst)
        tsc = self.joint_tsc(zts)

        s_logits = [s, ts, cs, tcs, cts]
        t_logits = [t, st, ct, sct, cst]
        c_logits = [c, sc, tc, stc, tsc]

        return s_logits, t_logits, c_logits