-
Notifications
You must be signed in to change notification settings - Fork 2
/
model.py
81 lines (61 loc) · 2.42 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.nn import Dropout, Linear
from torch.optim import Adam
from utils import get_feature_dis
class AE(nn.Module):
    """Single-hidden-layer autoencoder with dropout after every layer.

    Encoder: n_input -> n_hidden -> n_z; decoder mirrors it back to n_input.
    forward() returns (reconstruction, latent code).
    """

    def __init__(self, n_hidden, n_input, n_z, dropout):
        super(AE, self).__init__()
        self.dropout = dropout  # dropout probability used functionally in forward()
        # Encoder layers.
        self.enc_1 = Linear(n_input, n_hidden)
        self.z_layer = Linear(n_hidden, n_z)
        # Decoder layers.
        self.dec_1 = Linear(n_z, n_hidden)
        self.x_bar_layer = Linear(n_hidden, n_input)
        self._init_weights()

    def _init_weights(self):
        # Xavier-uniform weights followed by near-zero Gaussian biases.
        # NOTE: keep the draw order (all weights, then all biases) so seeded
        # runs reproduce the same parameters.
        layers = (self.enc_1, self.z_layer, self.dec_1, self.x_bar_layer)
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def reset_parameters(self):
        # Revert every layer to PyTorch's default Linear initialization.
        for layer in (self.enc_1, self.z_layer, self.dec_1, self.x_bar_layer):
            layer.reset_parameters()

    def forward(self, x):
        """Encode x to latent z and decode back; returns (x_bar, z)."""
        h = F.dropout(F.relu(self.enc_1(x)), p=self.dropout, training=self.training)
        z = self.z_layer(h)
        h = F.dropout(z, p=self.dropout, training=self.training)
        h = F.dropout(F.relu(self.dec_1(h)), p=self.dropout, training=self.training)
        x_bar = self.x_bar_layer(h)
        return x_bar, z
class NE_WNA(nn.Module):
    """Autoencoder backbone (AE) with a linear classification head.

    Training mode returns (reconstruction, feature-distance matrix, log-probs);
    evaluation mode returns only the log-probabilities.
    """

    def __init__(self, nhid, n_z, nfeat, nclass, dropout):
        super(NE_WNA, self).__init__()
        self.AE = AE(nhid, nfeat, n_z, dropout)
        self.classifier = Linear(n_z, nclass)

    def reset_parameters(self):
        # Re-initialize both the backbone and the head.
        self.AE.reset_parameters()
        self.classifier.reset_parameters()

    def forward(self, x):
        x_bar, z = self.AE(x)
        # The pairwise feature-similarity matrix is only needed by the
        # training loss, so skip computing it at evaluation time.
        x_dis = get_feature_dis(z) if self.training else 0
        class_logits = F.log_softmax(self.classifier(z), dim=1)
        if not self.training:
            return class_logits
        return x_bar, x_dis, class_logits