Skip to content

Commit

Permalink
Merge pull request #4 from IvLabs/embeddings
Browse files Browse the repository at this point in the history
Embeddings branch merge
  • Loading branch information
ABD-01 committed Jan 15, 2022
2 parents c43a7d4 + b78b8c6 commit ffb5c09
Show file tree
Hide file tree
Showing 15 changed files with 986 additions and 1,520 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ weights/
*/*.pyc
model/
wandb/
run1/
protoruns/
**/__pycache__/
**/**/__pycache__/
logs/
*.npy
centroids
*.pyc
model/
OfficeHomeDataset_10072016.zip
dalib/domainbed/__pycache__/*
1 change: 0 additions & 1 deletion OfficeHome_testrun/done

This file was deleted.

78 changes: 0 additions & 78 deletions OfficeHome_testrun/err.txt

This file was deleted.

475 changes: 0 additions & 475 deletions OfficeHome_testrun/out.txt

This file was deleted.

1 change: 0 additions & 1 deletion OfficeHome_testrun/proto_results.jsonl

This file was deleted.

2 changes: 0 additions & 2 deletions OfficeHome_testrun/results.jsonl

This file was deleted.

2 changes: 2 additions & 0 deletions common/vision/datasets/officehome.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ class OfficeHome(ImageList):
'Fan', 'Ruler', 'Pan', 'Screwdriver', 'Trash_Can', 'Printer', 'Speaker', 'Eraser', 'Bucket', 'Chair',
'Calendar', 'Calculator', 'Flowers', 'Lamp_Shade', 'Spoon', 'Candles', 'Clipboards', 'Scissors', 'TV',
'Curtains', 'Fork', 'Soda', 'Table', 'Knives', 'Oven', 'Refrigerator', 'Marker']
CLASSES = sorted(CLASSES)

def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):
assert task in self.image_list
self.domain_index = sorted(list(self.image_list.keys())).index(task)
data_list_file = os.path.join(root, self.image_list[task])

if download:
Expand Down
6 changes: 3 additions & 3 deletions dalib/adaptation/degaa.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self, input_dim: int, num_classes: int, gnn_layers: int, num_heads:

#self.final_proj = nn.Conv1d(self.input_dim, self.input_dim, kernel_size = 1, bias = True)
self.final_proj = nn.Linear(self.input_dim, self.input_dim)
self.head = nn.Linear(self.input_dim, self.num_classes)
# self.head = nn.Linear(self.input_dim, self.num_classes)

def forward(self, src, tgt):
#src = data['source'].double()
Expand All @@ -107,7 +107,7 @@ def forward(self, src, tgt):
src, tgt = self.gnn(src, tgt)

m_src, m_tgt = self.final_proj(src), self.final_proj(tgt)
y_src, y_tgt = self.head(m_src), self.head(m_tgt)
# y_src, y_tgt = self.head(m_src), self.head(m_tgt)

return m_src, m_tgt, y_src, y_tgt # To classification head
return m_src, m_tgt#, y_src, y_tgt # To classification head

Loading

0 comments on commit ffb5c09

Please sign in to comment.