-
Notifications
You must be signed in to change notification settings - Fork 3
/
pimodel.py
48 lines (36 loc) · 1.86 KB
/
pimodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import tensorflow as tf
from ..libml.data_augmentations import weak_augment, medium_augment, strong_augment
def pi_model(x, u, height, width):
    """
    Applies medium augmentations to the labeled batch x and, twice
    independently, to the unlabeled batch u.

    The two independent augmentations of u form the teacher/student input
    pair consumed by the Pi-model consistency loss (ssl_loss_pi_model).

    Args:
        x: tensor, labeled batch of images of shape [batch, height, width, channels]
        u: tensor, unlabeled batch of images of shape [batch, height, width, channels]
        height: int, height of images
        width: int, width of images

    Returns:
        The augmented labeled tensor and two independently augmented
        unlabeled tensors (teacher input, student input).
    """
    x_augment = medium_augment(x, height, width)
    # Two separate calls on purpose: teacher and student must see different
    # random augmentations of the same unlabeled images.
    u_teacher = medium_augment(u, height, width)
    u_student = medium_augment(u, height, width)
    return x_augment, u_teacher, u_student
@tf.function
def ssl_loss_pi_model(labels_x, logits_x, logits_teacher, logits_student):
    """
    Computes the two Pi-model semi-supervised losses.

    x_loss is the supervised cross-entropy loss on the labeled batch.
    pm_loss is the unsupervised consistency term: the mean squared error
    between the softmax outputs of the teacher and student passes over the
    same unlabeled images.

    Args:
        labels_x: tensor, labels corresponding to logits_x of shape [batch, num_classes]
        logits_x: tensor, logits of a batch of labeled images of shape [batch, num_classes]
        logits_teacher: tensor, logits of the teacher pass of shape [batch, num_classes]
        logits_student: tensor, logits of the student pass of shape [batch, num_classes]

    Returns:
        Two scalar tensors: the labeled cross-entropy loss and the
        unlabeled MSE consistency loss.
    """
    x_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels_x, logits=logits_x)
    x_loss = tf.reduce_mean(x_loss)
    # Squared difference of the two probability distributions, averaged over
    # the class axis per example, then averaged over the batch.
    pm_loss = tf.reduce_mean((tf.nn.softmax(logits_teacher) - tf.nn.softmax(logits_student)) ** 2, -1)
    pm_loss = tf.reduce_mean(pm_loss)
    return x_loss, pm_loss